1use std::io;
2use std::io::ErrorKind::WouldBlock;
3#[cfg(any(feature = "ssl", feature = "nativetls"))]
4use std::mem::replace;
5use std::net::SocketAddr;
6
7use bytes::{Buf, BufMut};
8use mio::tcp::TcpStream;
9#[cfg(feature = "nativetls")]
10use native_tls::{
11 HandshakeError, MidHandshakeTlsStream as MidHandshakeSslStream, TlsStream as SslStream,
12};
13#[cfg(feature = "ssl")]
14use openssl::ssl::{ErrorCode as SslErrorCode, HandshakeError, MidHandshakeSslStream, SslStream};
15
16use result::{Error, Kind, Result};
17
/// Adapts a blocking-style I/O result to non-blocking semantics.
///
/// * `Ok(v)` becomes `Ok(Some(v))`.
/// * An error of kind `WouldBlock` becomes `Ok(None)`, meaning "try again later".
/// * Any other error is passed through unchanged.
fn map_non_block<T>(res: io::Result<T>) -> io::Result<Option<T>> {
    res.map(Some).or_else(|err| {
        if err.kind() == WouldBlock {
            Ok(None)
        } else {
            Err(err)
        }
    })
}
30
31pub trait TryReadBuf: io::Read {
32 fn try_read_buf<B: BufMut>(&mut self, buf: &mut B) -> io::Result<Option<usize>>
33 where
34 Self: Sized,
35 {
36 let res = map_non_block(self.read(unsafe { buf.bytes_mut() }));
42
43 if let Ok(Some(cnt)) = res {
44 unsafe {
45 buf.advance_mut(cnt);
46 }
47 }
48
49 res
50 }
51}
52
53pub trait TryWriteBuf: io::Write {
54 fn try_write_buf<B: Buf>(&mut self, buf: &mut B) -> io::Result<Option<usize>>
55 where
56 Self: Sized,
57 {
58 let res = map_non_block(self.write(buf.bytes()));
59
60 if let Ok(Some(cnt)) = res {
61 buf.advance(cnt);
62 }
63
64 res
65 }
66}
67
68impl<T: io::Read> TryReadBuf for T {}
69impl<T: io::Write> TryWriteBuf for T {}
70
71use self::Stream::*;
/// A connection socket: plain TCP, or (when the `ssl` or `nativetls` feature
/// is enabled) a TLS session layered over TCP.
pub enum Stream {
    /// Unencrypted TCP connection.
    Tcp(TcpStream),
    /// TLS connection, either still handshaking or live; see `TlsStream`.
    #[cfg(any(feature = "ssl", feature = "nativetls"))]
    Tls(TlsStream),
}
77
78impl Stream {
79 pub fn tcp(stream: TcpStream) -> Stream {
80 Tcp(stream)
81 }
82
83 #[cfg(any(feature = "ssl", feature = "nativetls"))]
84 pub fn tls(stream: MidHandshakeSslStream<TcpStream>) -> Stream {
85 Tls(TlsStream::Handshake {
86 sock: stream,
87 negotiating: false,
88 })
89 }
90
91 #[cfg(any(feature = "ssl", feature = "nativetls"))]
92 pub fn tls_live(stream: SslStream<TcpStream>) -> Stream {
93 Tls(TlsStream::Live(stream))
94 }
95
96 #[cfg(any(feature = "ssl", feature = "nativetls"))]
97 pub fn is_tls(&self) -> bool {
98 match *self {
99 Tcp(_) => false,
100 Tls(_) => true,
101 }
102 }
103
104 pub fn evented(&self) -> &TcpStream {
105 match *self {
106 Tcp(ref sock) => sock,
107 #[cfg(any(feature = "ssl", feature = "nativetls"))]
108 Tls(ref inner) => inner.evented(),
109 }
110 }
111
112 pub fn is_negotiating(&self) -> bool {
113 match *self {
114 Tcp(_) => false,
115 #[cfg(any(feature = "ssl", feature = "nativetls"))]
116 Tls(ref inner) => inner.is_negotiating(),
117 }
118 }
119
120 pub fn clear_negotiating(&mut self) -> Result<()> {
121 match *self {
122 Tcp(_) => Err(Error::new(
123 Kind::Internal,
124 "Attempted to clear negotiating flag on non ssl connection.",
125 )),
126 #[cfg(any(feature = "ssl", feature = "nativetls"))]
127 Tls(ref mut inner) => inner.clear_negotiating(),
128 }
129 }
130
131 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
132 match *self {
133 Tcp(ref sock) => sock.peer_addr(),
134 #[cfg(any(feature = "ssl", feature = "nativetls"))]
135 Tls(ref inner) => inner.peer_addr(),
136 }
137 }
138
139 pub fn local_addr(&self) -> io::Result<SocketAddr> {
140 match *self {
141 Tcp(ref sock) => sock.local_addr(),
142 #[cfg(any(feature = "ssl", feature = "nativetls"))]
143 Tls(ref inner) => inner.local_addr(),
144 }
145 }
146}
147
/// Reading from a [`Stream`].
///
/// Plain TCP and live TLS streams read directly. A TLS stream that is still
/// handshaking first attempts one handshake step:
///
/// * On success it is promoted to `Live` and the read proceeds.
/// * Otherwise the handshake error is returned as an `io::Error`.
///   * With `ssl`, the underlying I/O error's kind is kept when there is one.
///   * With `nativetls`, an interrupted handshake is reported as `WouldBlock`.
impl io::Read for Stream {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match *self {
            Tcp(ref mut sock) => sock.read(buf),
            #[cfg(any(feature = "ssl", feature = "nativetls"))]
            Tls(TlsStream::Live(ref mut sock)) => sock.read(buf),
            // Any other TLS state: the handshake has not completed yet.
            #[cfg(any(feature = "ssl", feature = "nativetls"))]
            Tls(ref mut tls_stream) => {
                trace!("Attempting to read ssl handshake.");
                // `handshake()` consumes the mid-handshake stream, so move it out
                // by value and leave `Upgrading` behind as a placeholder. The
                // retryable arms below write a real state back.
                match replace(tls_stream, TlsStream::Upgrading) {
                    // `Live` is handled by the arm above.
                    // NOTE(review): `Upgrading` is reachable if an earlier call
                    // hit a non-retryable handshake error (those arms do not
                    // restore the state); this arm then panics.
                    TlsStream::Live(_) | TlsStream::Upgrading => unreachable!(),
                    TlsStream::Handshake {
                        sock,
                        mut negotiating,
                    } => match sock.handshake() {
                        // Handshake finished: serve this read from the new
                        // session and promote the state to `Live`.
                        Ok(mut sock) => {
                            trace!("Completed SSL Handshake");
                            let res = sock.read(buf);
                            *tls_stream = TlsStream::Live(sock);
                            res
                        }
                        // Not retryable here: the stream is not restored, so
                        // `tls_stream` stays `Upgrading`.
                        #[cfg(feature = "ssl")]
                        Err(HandshakeError::SetupFailure(err)) => {
                            Err(io::Error::new(io::ErrorKind::Other, err))
                        }
                        // Both kinds hand back the mid-handshake stream; the
                        // `Handshake` state is restored so the step can be retried.
                        #[cfg(feature = "ssl")]
                        Err(HandshakeError::Failure(mid))
                        | Err(HandshakeError::WouldBlock(mid)) => {
                            // Flag negotiation when OpenSSL wants to read.
                            // NOTE(review): unlike `Write::write`, this path never
                            // resets the flag to `false`; confirm the asymmetry
                            // is intended.
                            if mid.error().code() == SslErrorCode::WANT_READ {
                                negotiating = true;
                            }
                            // Keep the underlying I/O error's kind (e.g.
                            // `WouldBlock`) when there is one; otherwise report
                            // the OpenSSL error text as `Other`.
                            let err = if let Some(io_error) = mid.error().io_error() {
                                Err(io::Error::new(
                                    io_error.kind(),
                                    format!("{:?}", io_error.get_ref()),
                                ))
                            } else {
                                Err(io::Error::new(
                                    io::ErrorKind::Other,
                                    format!("{}", mid.error()),
                                ))
                            };
                            *tls_stream = TlsStream::Handshake {
                                sock: mid,
                                negotiating,
                            };
                            err
                        }
                        // Handshake not finished yet: restore the state, flag it
                        // as negotiating, and report `WouldBlock` to the caller.
                        #[cfg(feature = "nativetls")]
                        Err(HandshakeError::Interrupted(mid)) => {
                            negotiating = true;
                            *tls_stream = TlsStream::Handshake {
                                sock: mid,
                                negotiating: negotiating,
                            };
                            Err(io::Error::new(io::ErrorKind::WouldBlock, "SSL would block"))
                        }
                        // Not retryable here: the stream is not restored, so
                        // `tls_stream` stays `Upgrading`.
                        #[cfg(feature = "nativetls")]
                        Err(HandshakeError::Failure(err)) => {
                            Err(io::Error::new(io::ErrorKind::Other, format!("{}", err)))
                        }
                    },
                }
            }
        }
    }
}
215
/// Writing to a [`Stream`].
///
/// Plain TCP and live TLS streams write directly. A TLS stream that is still
/// handshaking first attempts one handshake step:
///
/// * On success it is promoted to `Live` and the write proceeds.
/// * Otherwise the handshake error is returned as an `io::Error`.
impl io::Write for Stream {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match *self {
            Tcp(ref mut sock) => sock.write(buf),
            #[cfg(any(feature = "ssl", feature = "nativetls"))]
            Tls(TlsStream::Live(ref mut sock)) => sock.write(buf),
            // Any other TLS state: the handshake has not completed yet.
            #[cfg(any(feature = "ssl", feature = "nativetls"))]
            Tls(ref mut tls_stream) => {
                trace!("Attempting to write ssl handshake.");
                // `handshake()` consumes the mid-handshake stream, so move it out
                // by value and leave `Upgrading` behind as a placeholder. The
                // retryable arms below write a real state back.
                match replace(tls_stream, TlsStream::Upgrading) {
                    // `Live` is handled by the arm above.
                    // NOTE(review): `Upgrading` is reachable if an earlier call
                    // hit a non-retryable handshake error (those arms do not
                    // restore the state); this arm then panics.
                    TlsStream::Live(_) | TlsStream::Upgrading => unreachable!(),
                    TlsStream::Handshake {
                        sock,
                        mut negotiating,
                    } => match sock.handshake() {
                        // Handshake finished: perform this write on the new
                        // session and promote the state to `Live`.
                        Ok(mut sock) => {
                            trace!("Completed SSL Handshake");
                            let res = sock.write(buf);
                            *tls_stream = TlsStream::Live(sock);
                            res
                        }
                        // Not retryable here: the stream is not restored, so
                        // `tls_stream` stays `Upgrading`.
                        #[cfg(feature = "ssl")]
                        Err(HandshakeError::SetupFailure(err)) => {
                            Err(io::Error::new(io::ErrorKind::Other, err))
                        }
                        // Both kinds hand back the mid-handshake stream; the
                        // `Handshake` state is restored so the step can be retried.
                        #[cfg(feature = "ssl")]
                        Err(HandshakeError::Failure(mid))
                        | Err(HandshakeError::WouldBlock(mid)) => {
                            // Set the flag when OpenSSL wants to read and clear
                            // it otherwise.
                            if mid.error().code() == SslErrorCode::WANT_READ {
                                negotiating = true;
                            } else {
                                negotiating = false;
                            }
                            // Keep the underlying I/O error's kind (e.g.
                            // `WouldBlock`) when there is one; otherwise report
                            // the OpenSSL error text as `Other`.
                            let err = if let Some(io_error) = mid.error().io_error() {
                                Err(io::Error::new(
                                    io_error.kind(),
                                    format!("{:?}", io_error.get_ref()),
                                ))
                            } else {
                                Err(io::Error::new(
                                    io::ErrorKind::Other,
                                    format!("{}", mid.error()),
                                ))
                            };
                            *tls_stream = TlsStream::Handshake {
                                sock: mid,
                                negotiating,
                            };
                            err
                        }
                        // Handshake not finished yet: restore the state, flag it
                        // as negotiating, and report `WouldBlock` to the caller.
                        #[cfg(feature = "nativetls")]
                        Err(HandshakeError::Interrupted(mid)) => {
                            negotiating = true;
                            *tls_stream = TlsStream::Handshake {
                                sock: mid,
                                negotiating: negotiating,
                            };
                            Err(io::Error::new(io::ErrorKind::WouldBlock, "SSL would block"))
                        }
                        // Not retryable here: the stream is not restored, so
                        // `tls_stream` stays `Upgrading`.
                        #[cfg(feature = "nativetls")]
                        Err(HandshakeError::Failure(err)) => {
                            Err(io::Error::new(io::ErrorKind::Other, format!("{}", err)))
                        }
                    },
                }
            }
        }
    }

    /// Flushes the stream.
    ///
    /// While a TLS handshake is in progress, only the wrapped socket returned
    /// by `get_mut()` is flushed. Panics if the TLS state is `Upgrading`.
    fn flush(&mut self) -> io::Result<()> {
        match *self {
            Tcp(ref mut sock) => sock.flush(),
            #[cfg(any(feature = "ssl", feature = "nativetls"))]
            Tls(TlsStream::Live(ref mut sock)) => sock.flush(),
            #[cfg(any(feature = "ssl", feature = "nativetls"))]
            Tls(TlsStream::Handshake { ref mut sock, .. }) => sock.get_mut().flush(),
            #[cfg(any(feature = "ssl", feature = "nativetls"))]
            Tls(TlsStream::Upgrading) => panic!("Tried to access actively upgrading TlsStream"),
        }
    }
}
297
/// State of a TLS connection over TCP.
#[cfg(any(feature = "ssl", feature = "nativetls"))]
pub enum TlsStream {
    /// Handshake completed; application data goes through the TLS session.
    Live(SslStream<TcpStream>),
    /// Handshake still in progress.
    Handshake {
        /// Partially negotiated TLS stream.
        sock: MidHandshakeSslStream<TcpStream>,
        /// Set by the `Stream` read/write handshake paths when negotiation was
        /// interrupted (for `ssl`, on `WANT_READ`). Cleared by `clear_negotiating`.
        negotiating: bool,
        },
    /// Placeholder left while the handshake state is temporarily moved out
    /// with `mem::replace`.
    ///
    /// It also persists after a non-retryable handshake error. Accessors in
    /// `impl TlsStream` panic if they observe it.
    Upgrading,
}
307
#[cfg(any(feature = "ssl", feature = "nativetls"))]
impl TlsStream {
    /// Panics with the message shared by every accessor that observes the
    /// transient `Upgrading` state.
    #[cold]
    fn upgrading() -> ! {
        panic!("Tried to access actively upgrading TlsStream")
    }

    /// Returns the underlying TCP socket, whether the session is live or
    /// still handshaking.
    ///
    /// # Panics
    /// Panics if the stream is `Upgrading`.
    pub fn evented(&self) -> &TcpStream {
        match *self {
            TlsStream::Live(ref sock) => sock.get_ref(),
            TlsStream::Handshake { ref sock, .. } => sock.get_ref(),
            TlsStream::Upgrading => Self::upgrading(),
        }
    }

    /// Returns the handshake's negotiating flag, or `false` once the session is live.
    ///
    /// # Panics
    /// Panics if the stream is `Upgrading`.
    pub fn is_negotiating(&self) -> bool {
        match *self {
            TlsStream::Live(_) => false,
            TlsStream::Handshake { negotiating, .. } => negotiating,
            TlsStream::Upgrading => Self::upgrading(),
        }
    }

    /// Clears the negotiating flag of an in-progress handshake.
    ///
    /// # Errors
    /// Returns a `Kind::Internal` error if the session is already live.
    ///
    /// # Panics
    /// Panics if the stream is `Upgrading`.
    pub fn clear_negotiating(&mut self) -> Result<()> {
        match *self {
            TlsStream::Live(_) => Err(Error::new(
                Kind::Internal,
                "Attempted to clear negotiating flag on live ssl connection.",
            )),
            TlsStream::Handshake {
                ref mut negotiating,
                ..
            } => {
                *negotiating = false;
                Ok(())
            }
            TlsStream::Upgrading => Self::upgrading(),
        }
    }

    /// Address of the remote peer of the underlying socket.
    ///
    /// # Panics
    /// Panics if the stream is `Upgrading` (via `evented`).
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        self.evented().peer_addr()
    }

    /// Local address of the underlying socket.
    ///
    /// # Panics
    /// Panics if the stream is `Upgrading` (via `evented`).
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.evented().local_addr()
    }
}