1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
use crate::{tcp, Error, Result}; use std::{ convert::TryFrom, fmt, io::{self, IoSlice, IoSliceMut, Read, Write}, }; pub struct TcpStream(Inner); enum Inner { Connected(tcp::TcpStream), Handshaking(Option<tcp::MidHandshakeTlsStream>), } impl TryFrom<tcp::HandshakeResult> for TcpStream { type Error = Error; fn try_from(result: tcp::HandshakeResult) -> Result<Self> { Ok(Self(match result { Ok(stream) => Inner::Connected(stream), Err(handshaker) => { Inner::Handshaking(Some(handshaker.into_mid_handshake_tls_stream()?)) } })) } } impl TcpStream { pub(crate) fn inner(&self) -> &tcp::TcpStream { match self.0 { Inner::Connected(ref stream) => stream, Inner::Handshaking(ref handshaker) => handshaker.as_ref().unwrap().get_ref(), } } pub(crate) fn inner_mut(&mut self) -> &mut tcp::TcpStream { match self.0 { Inner::Connected(ref mut stream) => stream, Inner::Handshaking(ref mut handshaker) => handshaker.as_mut().unwrap().get_mut(), } } pub(crate) fn is_handshaking(&self) -> bool { if let Inner::Handshaking(_) = self.0 { true } else { false } } pub(crate) fn handshake(&mut self) -> Result<()> { if let Inner::Handshaking(ref mut handshaker) = self.0 { match handshaker.take().unwrap().handshake() { Ok(stream) => self.0 = Inner::Connected(stream), Err(error) => { self.0 = Inner::Handshaking(Some(error.into_mid_handshake_tls_stream()?)) } } } Ok(()) } } macro_rules! fwd_impl { ($self:ident, $method:ident, $($args:expr),*) => { match $self.0 { Inner::Connected(ref mut inner) => inner.$method($($args),*), Inner::Handshaking(ref mut inner) => inner.as_mut().unwrap().get_mut().$method($($args),*), } }; } impl Read for TcpStream { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fwd_impl!(self, read, buf) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> { fwd_impl!(self, read_vectored, bufs) } fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> { fwd_impl!(self, read_to_end, buf) } fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> { fwd_impl!(self, read_to_string, buf) } fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> { fwd_impl!(self, read_exact, buf) } } impl Write for TcpStream { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fwd_impl!(self, write, buf) } fn flush(&mut self) -> io::Result<()> { fwd_impl!(self, flush,) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> { fwd_impl!(self, write_vectored, bufs) } fn write_all(&mut self, buf: &[u8]) -> io::Result<()> { fwd_impl!(self, write_all, buf) } fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> { fwd_impl!(self, write_fmt, fmt) } }