qws/
stream.rs

1use std::io;
2use std::io::ErrorKind::WouldBlock;
3#[cfg(any(feature = "ssl", feature = "nativetls"))]
4use std::mem::replace;
5use std::net::SocketAddr;
6
7use bytes::{Buf, BufMut};
8use mio::tcp::TcpStream;
9#[cfg(feature = "nativetls")]
10use native_tls::{
11    HandshakeError, MidHandshakeTlsStream as MidHandshakeSslStream, TlsStream as SslStream,
12};
13#[cfg(feature = "ssl")]
14use openssl::ssl::{ErrorCode as SslErrorCode, HandshakeError, MidHandshakeSslStream, SslStream};
15
16use result::{Error, Kind, Result};
17
/// Converts a `WouldBlock` error into `Ok(None)` so non-blocking callers can tell
/// "no progress yet" apart from a real failure.
///
/// A successful value is wrapped in `Some`; every other error is passed through unchanged.
fn map_non_block<T>(res: io::Result<T>) -> io::Result<Option<T>> {
    match res {
        Ok(value) => Ok(Some(value)),
        Err(ref err) if err.kind() == WouldBlock => Ok(None),
        Err(err) => Err(err),
    }
}
30
31pub trait TryReadBuf: io::Read {
32    fn try_read_buf<B: BufMut>(&mut self, buf: &mut B) -> io::Result<Option<usize>>
33    where
34        Self: Sized,
35    {
36        // Reads the length of the slice supplied by buf.mut_bytes into the buffer
37        // This is not guaranteed to consume an entire datagram or segment.
38        // If your protocol is msg based (instead of continuous stream) you should
39        // ensure that your buffer is large enough to hold an entire segment (1532 bytes if not jumbo
40        // frames)
41        let res = map_non_block(self.read(unsafe { buf.bytes_mut() }));
42
43        if let Ok(Some(cnt)) = res {
44            unsafe {
45                buf.advance_mut(cnt);
46            }
47        }
48
49        res
50    }
51}
52
53pub trait TryWriteBuf: io::Write {
54    fn try_write_buf<B: Buf>(&mut self, buf: &mut B) -> io::Result<Option<usize>>
55    where
56        Self: Sized,
57    {
58        let res = map_non_block(self.write(buf.bytes()));
59
60        if let Ok(Some(cnt)) = res {
61            buf.advance(cnt);
62        }
63
64        res
65    }
66}
67
68impl<T: io::Read> TryReadBuf for T {}
69impl<T: io::Write> TryWriteBuf for T {}
70
71use self::Stream::*;
72pub enum Stream {
73    Tcp(TcpStream),
74    #[cfg(any(feature = "ssl", feature = "nativetls"))]
75    Tls(TlsStream),
76}
77
78impl Stream {
79    pub fn tcp(stream: TcpStream) -> Stream {
80        Tcp(stream)
81    }
82
83    #[cfg(any(feature = "ssl", feature = "nativetls"))]
84    pub fn tls(stream: MidHandshakeSslStream<TcpStream>) -> Stream {
85        Tls(TlsStream::Handshake {
86            sock: stream,
87            negotiating: false,
88        })
89    }
90
91    #[cfg(any(feature = "ssl", feature = "nativetls"))]
92    pub fn tls_live(stream: SslStream<TcpStream>) -> Stream {
93        Tls(TlsStream::Live(stream))
94    }
95
96    #[cfg(any(feature = "ssl", feature = "nativetls"))]
97    pub fn is_tls(&self) -> bool {
98        match *self {
99            Tcp(_) => false,
100            Tls(_) => true,
101        }
102    }
103
104    pub fn evented(&self) -> &TcpStream {
105        match *self {
106            Tcp(ref sock) => sock,
107            #[cfg(any(feature = "ssl", feature = "nativetls"))]
108            Tls(ref inner) => inner.evented(),
109        }
110    }
111
112    pub fn is_negotiating(&self) -> bool {
113        match *self {
114            Tcp(_) => false,
115            #[cfg(any(feature = "ssl", feature = "nativetls"))]
116            Tls(ref inner) => inner.is_negotiating(),
117        }
118    }
119
120    pub fn clear_negotiating(&mut self) -> Result<()> {
121        match *self {
122            Tcp(_) => Err(Error::new(
123                Kind::Internal,
124                "Attempted to clear negotiating flag on non ssl connection.",
125            )),
126            #[cfg(any(feature = "ssl", feature = "nativetls"))]
127            Tls(ref mut inner) => inner.clear_negotiating(),
128        }
129    }
130
131    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
132        match *self {
133            Tcp(ref sock) => sock.peer_addr(),
134            #[cfg(any(feature = "ssl", feature = "nativetls"))]
135            Tls(ref inner) => inner.peer_addr(),
136        }
137    }
138
139    pub fn local_addr(&self) -> io::Result<SocketAddr> {
140        match *self {
141            Tcp(ref sock) => sock.local_addr(),
142            #[cfg(any(feature = "ssl", feature = "nativetls"))]
143            Tls(ref inner) => inner.local_addr(),
144        }
145    }
146}
147
148impl io::Read for Stream {
149    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
150        match *self {
151            Tcp(ref mut sock) => sock.read(buf),
152            #[cfg(any(feature = "ssl", feature = "nativetls"))]
153            Tls(TlsStream::Live(ref mut sock)) => sock.read(buf),
154            #[cfg(any(feature = "ssl", feature = "nativetls"))]
155            Tls(ref mut tls_stream) => {
156                trace!("Attempting to read ssl handshake.");
157                match replace(tls_stream, TlsStream::Upgrading) {
158                    TlsStream::Live(_) | TlsStream::Upgrading => unreachable!(),
159                    TlsStream::Handshake {
160                        sock,
161                        mut negotiating,
162                    } => match sock.handshake() {
163                        Ok(mut sock) => {
164                            trace!("Completed SSL Handshake");
165                            let res = sock.read(buf);
166                            *tls_stream = TlsStream::Live(sock);
167                            res
168                        }
169                        #[cfg(feature = "ssl")]
170                        Err(HandshakeError::SetupFailure(err)) => {
171                            Err(io::Error::new(io::ErrorKind::Other, err))
172                        }
173                        #[cfg(feature = "ssl")]
174                        Err(HandshakeError::Failure(mid))
175                        | Err(HandshakeError::WouldBlock(mid)) => {
176                            if mid.error().code() == SslErrorCode::WANT_READ {
177                                negotiating = true;
178                            }
179                            let err = if let Some(io_error) = mid.error().io_error() {
180                                Err(io::Error::new(
181                                    io_error.kind(),
182                                    format!("{:?}", io_error.get_ref()),
183                                ))
184                            } else {
185                                Err(io::Error::new(
186                                    io::ErrorKind::Other,
187                                    format!("{}", mid.error()),
188                                ))
189                            };
190                            *tls_stream = TlsStream::Handshake {
191                                sock: mid,
192                                negotiating,
193                            };
194                            err
195                        }
196                        #[cfg(feature = "nativetls")]
197                        Err(HandshakeError::Interrupted(mid)) => {
198                            negotiating = true;
199                            *tls_stream = TlsStream::Handshake {
200                                sock: mid,
201                                negotiating: negotiating,
202                            };
203                            Err(io::Error::new(io::ErrorKind::WouldBlock, "SSL would block"))
204                        }
205                        #[cfg(feature = "nativetls")]
206                        Err(HandshakeError::Failure(err)) => {
207                            Err(io::Error::new(io::ErrorKind::Other, format!("{}", err)))
208                        }
209                    },
210                }
211            }
212        }
213    }
214}
215
216impl io::Write for Stream {
217    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
218        match *self {
219            Tcp(ref mut sock) => sock.write(buf),
220            #[cfg(any(feature = "ssl", feature = "nativetls"))]
221            Tls(TlsStream::Live(ref mut sock)) => sock.write(buf),
222            #[cfg(any(feature = "ssl", feature = "nativetls"))]
223            Tls(ref mut tls_stream) => {
224                trace!("Attempting to write ssl handshake.");
225                match replace(tls_stream, TlsStream::Upgrading) {
226                    TlsStream::Live(_) | TlsStream::Upgrading => unreachable!(),
227                    TlsStream::Handshake {
228                        sock,
229                        mut negotiating,
230                    } => match sock.handshake() {
231                        Ok(mut sock) => {
232                            trace!("Completed SSL Handshake");
233                            let res = sock.write(buf);
234                            *tls_stream = TlsStream::Live(sock);
235                            res
236                        }
237                        #[cfg(feature = "ssl")]
238                        Err(HandshakeError::SetupFailure(err)) => {
239                            Err(io::Error::new(io::ErrorKind::Other, err))
240                        }
241                        #[cfg(feature = "ssl")]
242                        Err(HandshakeError::Failure(mid))
243                        | Err(HandshakeError::WouldBlock(mid)) => {
244                            if mid.error().code() == SslErrorCode::WANT_READ {
245                                negotiating = true;
246                            } else {
247                                negotiating = false;
248                            }
249                            let err = if let Some(io_error) = mid.error().io_error() {
250                                Err(io::Error::new(
251                                    io_error.kind(),
252                                    format!("{:?}", io_error.get_ref()),
253                                ))
254                            } else {
255                                Err(io::Error::new(
256                                    io::ErrorKind::Other,
257                                    format!("{}", mid.error()),
258                                ))
259                            };
260                            *tls_stream = TlsStream::Handshake {
261                                sock: mid,
262                                negotiating,
263                            };
264                            err
265                        }
266                        #[cfg(feature = "nativetls")]
267                        Err(HandshakeError::Interrupted(mid)) => {
268                            negotiating = true;
269                            *tls_stream = TlsStream::Handshake {
270                                sock: mid,
271                                negotiating: negotiating,
272                            };
273                            Err(io::Error::new(io::ErrorKind::WouldBlock, "SSL would block"))
274                        }
275                        #[cfg(feature = "nativetls")]
276                        Err(HandshakeError::Failure(err)) => {
277                            Err(io::Error::new(io::ErrorKind::Other, format!("{}", err)))
278                        }
279                    },
280                }
281            }
282        }
283    }
284
285    fn flush(&mut self) -> io::Result<()> {
286        match *self {
287            Tcp(ref mut sock) => sock.flush(),
288            #[cfg(any(feature = "ssl", feature = "nativetls"))]
289            Tls(TlsStream::Live(ref mut sock)) => sock.flush(),
290            #[cfg(any(feature = "ssl", feature = "nativetls"))]
291            Tls(TlsStream::Handshake { ref mut sock, .. }) => sock.get_mut().flush(),
292            #[cfg(any(feature = "ssl", feature = "nativetls"))]
293            Tls(TlsStream::Upgrading) => panic!("Tried to access actively upgrading TlsStream"),
294        }
295    }
296}
297
/// State of a TLS connection.
#[cfg(any(feature = "ssl", feature = "nativetls"))]
pub enum TlsStream {
    /// The handshake has completed; application data flows through this stream.
    Live(SslStream<TcpStream>),
    /// The handshake is still in progress.
    Handshake {
        /// Partially negotiated stream; the next read or write resumes the handshake.
        sock: MidHandshakeSslStream<TcpStream>,
        /// Set when a handshake step reports OpenSSL `WANT_READ` or a native-tls
        /// `Interrupted` error; cleared through `clear_negotiating`.
        negotiating: bool,
    },
    /// Placeholder swapped in while a handshake step owns the previous state. If that step
    /// fails without handing the stream back, the socket is dropped and this variant remains.
    Upgrading,
}
307
#[cfg(any(feature = "ssl", feature = "nativetls"))]
impl TlsStream {
    /// Borrows the TCP socket beneath the TLS layer.
    ///
    /// # Panics
    ///
    /// Panics in the `Upgrading` state. That state persists after a handshake failure that
    /// dropped the socket, so there is nothing to return.
    pub fn evented(&self) -> &TcpStream {
        match *self {
            TlsStream::Live(ref sock) => sock.get_ref(),
            TlsStream::Handshake { ref sock, .. } => sock.get_ref(),
            TlsStream::Upgrading => panic!("Tried to access actively upgrading TlsStream"),
        }
    }

    /// Returns the `negotiating` flag of a pending handshake.
    ///
    /// Live streams, and streams left `Upgrading` by a failed handshake, are never
    /// negotiating.
    pub fn is_negotiating(&self) -> bool {
        match *self {
            TlsStream::Handshake { negotiating, .. } => negotiating,
            TlsStream::Live(_) | TlsStream::Upgrading => false,
        }
    }

    /// Clears the `negotiating` flag of a pending handshake.
    ///
    /// # Errors
    ///
    /// Returns a `Kind::Internal` error if the handshake has already completed or failed.
    pub fn clear_negotiating(&mut self) -> Result<()> {
        match *self {
            TlsStream::Live(_) => Err(Error::new(
                Kind::Internal,
                "Attempted to clear negotiating flag on live ssl connection.",
            )),
            TlsStream::Handshake {
                ref mut negotiating,
                ..
            } => {
                *negotiating = false;
                Ok(())
            }
            TlsStream::Upgrading => Err(Error::new(
                Kind::Internal,
                "Attempted to clear negotiating flag on failed ssl connection.",
            )),
        }
    }

    /// Address of the remote peer.
    ///
    /// # Errors
    ///
    /// Propagates socket errors. Returns `ErrorKind::NotConnected` if a failed handshake
    /// dropped the socket.
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        self.socket()?.peer_addr()
    }

    /// Local address the underlying socket is bound to.
    ///
    /// # Errors
    ///
    /// Propagates socket errors. Returns `ErrorKind::NotConnected` if a failed handshake
    /// dropped the socket.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.socket()?.local_addr()
    }

    /// Borrows the TCP socket, or reports `NotConnected` when it has been dropped.
    fn socket(&self) -> io::Result<&TcpStream> {
        match *self {
            TlsStream::Live(ref sock) => Ok(sock.get_ref()),
            TlsStream::Handshake { ref sock, .. } => Ok(sock.get_ref()),
            TlsStream::Upgrading => Err(io::Error::new(
                io::ErrorKind::NotConnected,
                "TLS handshake failed; stream is no longer usable",
            )),
        }
    }
}