mio/net/tcp/
stream.rs

1use std::fmt;
2use std::io::{self, IoSlice, IoSliceMut, Read, Write};
3use std::net::{self, Shutdown, SocketAddr};
4#[cfg(any(unix, target_os = "wasi"))]
5use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, OwnedFd, RawFd};
6// TODO: once <https://github.com/rust-lang/rust/issues/126198> is fixed this
7// can use `std::os::fd` and be merged with the above.
8#[cfg(target_os = "hermit")]
9use std::os::hermit::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, OwnedFd, RawFd};
10#[cfg(windows)]
11use std::os::windows::io::{
12    AsRawSocket, AsSocket, BorrowedSocket, FromRawSocket, IntoRawSocket, OwnedSocket, RawSocket,
13};
14
15use crate::io_source::IoSource;
16#[cfg(not(target_os = "wasi"))]
17use crate::sys::tcp::{connect, new_for_addr};
18use crate::{event, Interest, Registry, Token};
19
20/// A non-blocking TCP stream between a local socket and a remote socket.
21///
22/// The socket will be closed when the value is dropped.
23///
24/// # Examples
25///
26#[cfg_attr(feature = "os-poll", doc = "```")]
27#[cfg_attr(not(feature = "os-poll"), doc = "```ignore")]
28/// # use std::net::{TcpListener, SocketAddr};
29/// # use std::error::Error;
30/// #
31/// # fn main() -> Result<(), Box<dyn Error>> {
32/// let address: SocketAddr = "127.0.0.1:0".parse()?;
33/// let listener = TcpListener::bind(address)?;
34/// use mio::{Events, Interest, Poll, Token};
35/// use mio::net::TcpStream;
36/// use std::time::Duration;
37///
38/// let mut stream = TcpStream::connect(listener.local_addr()?)?;
39///
40/// let mut poll = Poll::new()?;
41/// let mut events = Events::with_capacity(128);
42///
43/// // Register the socket with `Poll`
44/// poll.registry().register(&mut stream, Token(0), Interest::WRITABLE)?;
45///
46/// poll.poll(&mut events, Some(Duration::from_millis(100)))?;
47///
48/// // The socket might be ready at this point
49/// #     Ok(())
50/// # }
51/// ```
52pub struct TcpStream {
53    inner: IoSource<net::TcpStream>,
54}
55
56impl TcpStream {
57    /// Create a new TCP stream and issue a non-blocking connect to the
58    /// specified address.
59    ///
60    /// # Notes
61    ///
62    /// The returned `TcpStream` may not be connected (and thus usable), unlike
63    /// the API found in `std::net::TcpStream`. Because Mio issues a
64    /// *non-blocking* connect it will not block the thread and instead return
65    /// an unconnected `TcpStream`.
66    ///
67    /// Ensuring the returned stream is connected is surprisingly complex when
68    /// considering cross-platform support. Doing this properly should follow
69    /// the steps below, an example implementation can be found
70    /// [here](https://github.com/Thomasdezeeuw/heph/blob/0c4f1ab3eaf08bea1d65776528bfd6114c9f8374/src/net/tcp/stream.rs#L560-L622).
71    ///
72    ///  1. Call `TcpStream::connect`
73    ///  2. Register the returned stream with at least [write interest].
74    ///  3. Wait for a (writable) event.
75    ///  4. Check `TcpStream::take_error`. If it returns an error, then
76    ///     something went wrong. If it returns `Ok(None)`, then proceed to
77    ///     step 5.
78    ///  5. Check `TcpStream::peer_addr`. If it returns `libc::EINPROGRESS` or
79    ///     `ErrorKind::NotConnected` it means the stream is not yet connected,
80    ///     go back to step 3. If it returns an address it means the stream is
81    ///     connected, go to step 6. If another error is returned something
82    ///     went wrong.
83    ///  6. Now the stream can be used.
84    ///
85    /// This may return a `WouldBlock` in which case the socket connection
86    /// cannot be completed immediately, it usually means there are insufficient
87    /// entries in the routing cache.
88    ///
89    /// [write interest]: Interest::WRITABLE
90    #[cfg(not(target_os = "wasi"))]
91    pub fn connect(addr: SocketAddr) -> io::Result<TcpStream> {
92        let socket = new_for_addr(addr)?;
93        #[cfg(any(unix, target_os = "hermit"))]
94        let stream = unsafe { TcpStream::from_raw_fd(socket) };
95        #[cfg(windows)]
96        let stream = unsafe { TcpStream::from_raw_socket(socket as _) };
97        connect(&stream.inner, addr)?;
98        Ok(stream)
99    }
100
101    /// Creates a new `TcpStream` from a standard `net::TcpStream`.
102    ///
103    /// This function is intended to be used to wrap a TCP stream from the
104    /// standard library in the Mio equivalent. The conversion assumes nothing
105    /// about the underlying stream; it is left up to the user to set it in
106    /// non-blocking mode.
107    ///
108    /// # Note
109    ///
110    /// The TCP stream here will not have `connect` called on it, so it
111    /// should already be connected via some other means (be it manually, or
112    /// the standard library).
113    pub fn from_std(stream: net::TcpStream) -> TcpStream {
114        TcpStream {
115            inner: IoSource::new(stream),
116        }
117    }
118
119    /// Returns the socket address of the remote peer of this TCP connection.
120    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
121        self.inner.peer_addr()
122    }
123
124    /// Returns the socket address of the local half of this TCP connection.
125    pub fn local_addr(&self) -> io::Result<SocketAddr> {
126        self.inner.local_addr()
127    }
128
129    /// Shuts down the read, write, or both halves of this connection.
130    ///
131    /// This function will cause all pending and future I/O on the specified
132    /// portions to return immediately with an appropriate value (see the
133    /// documentation of `Shutdown`).
134    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
135        self.inner.shutdown(how)
136    }
137
138    /// Sets the value of the `TCP_NODELAY` option on this socket.
139    ///
140    /// If set, this option disables the Nagle algorithm. This means that
141    /// segments are always sent as soon as possible, even if there is only a
142    /// small amount of data. When not set, data is buffered until there is a
143    /// sufficient amount to send out, thereby avoiding the frequent sending of
144    /// small packets.
145    ///
146    /// # Notes
147    ///
148    /// On Windows make sure the stream is connected before calling this method,
149    /// by receiving an (writable) event. Trying to set `nodelay` on an
150    /// unconnected `TcpStream` is unspecified behavior.
151    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
152        self.inner.set_nodelay(nodelay)
153    }
154
155    /// Gets the value of the `TCP_NODELAY` option on this socket.
156    ///
157    /// For more information about this option, see [`set_nodelay`][link].
158    ///
159    /// [link]: #method.set_nodelay
160    ///
161    /// # Notes
162    ///
163    /// On Windows make sure the stream is connected before calling this method,
164    /// by receiving an (writable) event. Trying to get `nodelay` on an
165    /// unconnected `TcpStream` is unspecified behavior.
166    pub fn nodelay(&self) -> io::Result<bool> {
167        self.inner.nodelay()
168    }
169
170    /// Sets the value for the `IP_TTL` option on this socket.
171    ///
172    /// This value sets the time-to-live field that is used in every packet sent
173    /// from this socket.
174    ///
175    /// # Notes
176    ///
177    /// On Windows make sure the stream is connected before calling this method,
178    /// by receiving an (writable) event. Trying to set `ttl` on an
179    /// unconnected `TcpStream` is unspecified behavior.
180    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
181        self.inner.set_ttl(ttl)
182    }
183
184    /// Gets the value of the `IP_TTL` option for this socket.
185    ///
186    /// For more information about this option, see [`set_ttl`][link].
187    ///
188    /// # Notes
189    ///
190    /// On Windows make sure the stream is connected before calling this method,
191    /// by receiving an (writable) event. Trying to get `ttl` on an
192    /// unconnected `TcpStream` is unspecified behavior.
193    ///
194    /// [link]: #method.set_ttl
195    pub fn ttl(&self) -> io::Result<u32> {
196        self.inner.ttl()
197    }
198
199    /// Get the value of the `SO_ERROR` option on this socket.
200    ///
201    /// This will retrieve the stored error in the underlying socket, clearing
202    /// the field in the process. This can be useful for checking errors between
203    /// calls.
204    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
205        self.inner.take_error()
206    }
207
208    /// Receives data on the socket from the remote address to which it is
209    /// connected, without removing that data from the queue. On success,
210    /// returns the number of bytes peeked.
211    ///
212    /// Successive calls return the same data. This is accomplished by passing
213    /// `MSG_PEEK` as a flag to the underlying recv system call.
214    pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
215        // Need to re-register if `peek` returns `WouldBlock`
216        // to ensure the socket will receive more events once it is ready again.
217        self.inner.do_io(|inner| inner.peek(buf))
218    }
219
220    /// Execute an I/O operation ensuring that the socket receives more events
221    /// if it hits a [`WouldBlock`] error.
222    ///
223    /// # Notes
224    ///
225    /// This method is required to be called for **all** I/O operations to
226    /// ensure the user will receive events once the socket is ready again after
227    /// returning a [`WouldBlock`] error.
228    ///
229    /// [`WouldBlock`]: io::ErrorKind::WouldBlock
230    ///
231    /// # Examples
232    ///
233    #[cfg_attr(unix, doc = "```no_run")]
234    #[cfg_attr(windows, doc = "```ignore")]
235    /// # use std::error::Error;
236    /// #
237    /// # fn main() -> Result<(), Box<dyn Error>> {
238    /// use std::io;
239    /// #[cfg(any(unix, target_os = "wasi"))]
240    /// use std::os::fd::AsRawFd;
241    /// #[cfg(windows)]
242    /// use std::os::windows::io::AsRawSocket;
243    /// use mio::net::TcpStream;
244    ///
245    /// let address = "127.0.0.1:8080".parse().unwrap();
246    /// let stream = TcpStream::connect(address)?;
247    ///
248    /// // Wait until the stream is readable...
249    ///
250    /// // Read from the stream using a direct libc call, of course the
251    /// // `io::Read` implementation would be easier to use.
252    /// let mut buf = [0; 512];
253    /// let n = stream.try_io(|| {
254    ///     let buf_ptr = &mut buf as *mut _ as *mut _;
255    ///     #[cfg(unix)]
256    ///     let res = unsafe { libc::recv(stream.as_raw_fd(), buf_ptr, buf.len(), 0) };
257    ///     #[cfg(windows)]
258    ///     let res = unsafe { libc::recvfrom(stream.as_raw_socket() as usize, buf_ptr, buf.len() as i32, 0, std::ptr::null_mut(), std::ptr::null_mut()) };
259    ///     if res != -1 {
260    ///         Ok(res as usize)
261    ///     } else {
262    ///         // If EAGAIN or EWOULDBLOCK is set by libc::recv, the closure
263    ///         // should return `WouldBlock` error.
264    ///         Err(io::Error::last_os_error())
265    ///     }
266    /// })?;
267    /// eprintln!("read {} bytes", n);
268    /// # Ok(())
269    /// # }
270    /// ```
271    pub fn try_io<F, T>(&self, f: F) -> io::Result<T>
272    where
273        F: FnOnce() -> io::Result<T>,
274    {
275        self.inner.do_io(|_| f())
276    }
277}
278
279impl Read for TcpStream {
280    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
281        self.inner.do_io(|mut inner| inner.read(buf))
282    }
283
284    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
285        self.inner.do_io(|mut inner| inner.read_vectored(bufs))
286    }
287}
288
289impl Read for &'_ TcpStream {
290    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
291        self.inner.do_io(|mut inner| inner.read(buf))
292    }
293
294    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
295        self.inner.do_io(|mut inner| inner.read_vectored(bufs))
296    }
297}
298
299impl Write for TcpStream {
300    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
301        self.inner.do_io(|mut inner| inner.write(buf))
302    }
303
304    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
305        self.inner.do_io(|mut inner| inner.write_vectored(bufs))
306    }
307
308    fn flush(&mut self) -> io::Result<()> {
309        self.inner.do_io(|mut inner| inner.flush())
310    }
311}
312
313impl Write for &'_ TcpStream {
314    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
315        self.inner.do_io(|mut inner| inner.write(buf))
316    }
317
318    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
319        self.inner.do_io(|mut inner| inner.write_vectored(bufs))
320    }
321
322    fn flush(&mut self) -> io::Result<()> {
323        self.inner.do_io(|mut inner| inner.flush())
324    }
325}
326
327impl event::Source for TcpStream {
328    fn register(
329        &mut self,
330        registry: &Registry,
331        token: Token,
332        interests: Interest,
333    ) -> io::Result<()> {
334        self.inner.register(registry, token, interests)
335    }
336
337    fn reregister(
338        &mut self,
339        registry: &Registry,
340        token: Token,
341        interests: Interest,
342    ) -> io::Result<()> {
343        self.inner.reregister(registry, token, interests)
344    }
345
346    fn deregister(&mut self, registry: &Registry) -> io::Result<()> {
347        self.inner.deregister(registry)
348    }
349}
350
351impl fmt::Debug for TcpStream {
352    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
353        self.inner.fmt(f)
354    }
355}
356
357#[cfg(any(unix, target_os = "hermit", target_os = "wasi"))]
358impl IntoRawFd for TcpStream {
359    fn into_raw_fd(self) -> RawFd {
360        self.inner.into_inner().into_raw_fd()
361    }
362}
363
364#[cfg(any(unix, target_os = "hermit", target_os = "wasi"))]
365impl AsRawFd for TcpStream {
366    fn as_raw_fd(&self) -> RawFd {
367        self.inner.as_raw_fd()
368    }
369}
370
371#[cfg(any(unix, target_os = "hermit", target_os = "wasi"))]
372impl FromRawFd for TcpStream {
373    /// Converts a `RawFd` to a `TcpStream`.
374    ///
375    /// # Notes
376    ///
377    /// The caller is responsible for ensuring that the socket is in
378    /// non-blocking mode.
379    unsafe fn from_raw_fd(fd: RawFd) -> TcpStream {
380        TcpStream::from_std(FromRawFd::from_raw_fd(fd))
381    }
382}
383
384#[cfg(any(unix, target_os = "hermit", target_os = "wasi"))]
385impl From<TcpStream> for OwnedFd {
386    fn from(tcp_stream: TcpStream) -> Self {
387        tcp_stream.inner.into_inner().into()
388    }
389}
390
391#[cfg(any(unix, target_os = "hermit", target_os = "wasi"))]
392impl AsFd for TcpStream {
393    fn as_fd(&self) -> BorrowedFd<'_> {
394        self.inner.as_fd()
395    }
396}
397
398#[cfg(any(unix, target_os = "hermit", target_os = "wasi"))]
399impl From<OwnedFd> for TcpStream {
400    /// Converts a `RawFd` to a `TcpStream`.
401    ///
402    /// # Notes
403    ///
404    /// The caller is responsible for ensuring that the socket is in
405    /// non-blocking mode.
406    fn from(fd: OwnedFd) -> Self {
407        TcpStream::from_std(From::from(fd))
408    }
409}
410
#[cfg(windows)]
impl IntoRawSocket for TcpStream {
    /// Consumes the stream and returns its raw socket handle. Closing the
    /// socket becomes the caller's responsibility.
    fn into_raw_socket(self) -> RawSocket {
        let std_stream = self.inner.into_inner();
        IntoRawSocket::into_raw_socket(std_stream)
    }
}
417
#[cfg(windows)]
impl AsRawSocket for TcpStream {
    /// Returns the raw socket handle of the wrapped stream. Ownership stays
    /// with the `TcpStream`.
    fn as_raw_socket(&self) -> RawSocket {
        AsRawSocket::as_raw_socket(&*self.inner)
    }
}
424
#[cfg(windows)]
impl FromRawSocket for TcpStream {
    /// Builds a `TcpStream` that takes ownership of the raw handle `socket`.
    ///
    /// # Notes
    ///
    /// The socket's blocking mode is left untouched. Switching it to
    /// non-blocking mode is up to the caller.
    unsafe fn from_raw_socket(socket: RawSocket) -> TcpStream {
        let std_stream = net::TcpStream::from_raw_socket(socket);
        TcpStream::from_std(std_stream)
    }
}
437
#[cfg(windows)]
impl From<TcpStream> for OwnedSocket {
    /// Consumes the stream and moves ownership of its socket into an
    /// `OwnedSocket`.
    fn from(tcp_stream: TcpStream) -> Self {
        let std_stream = tcp_stream.inner.into_inner();
        OwnedSocket::from(std_stream)
    }
}
444
#[cfg(windows)]
impl AsSocket for TcpStream {
    /// Borrows the wrapped socket handle for the lifetime of `&self`.
    fn as_socket(&self) -> BorrowedSocket<'_> {
        AsSocket::as_socket(&*self.inner)
    }
}
451
#[cfg(windows)]
impl From<OwnedSocket> for TcpStream {
    /// Converts an `OwnedSocket` to a `TcpStream`, taking ownership of the
    /// socket.
    ///
    /// # Notes
    ///
    /// The caller is responsible for ensuring that the socket is in
    /// non-blocking mode. This conversion does not check or change the
    /// socket's mode.
    fn from(socket: OwnedSocket) -> Self {
        TcpStream::from_std(From::from(socket))
    }
}
464
465impl From<TcpStream> for net::TcpStream {
466    fn from(stream: TcpStream) -> Self {
467        // Safety: This is safe since we are extracting the raw fd from a well-constructed
468        // mio::net::TcpStream which ensures that we actually pass in a valid file
469        // descriptor/socket
470        unsafe {
471            #[cfg(any(unix, target_os = "hermit", target_os = "wasi"))]
472            {
473                net::TcpStream::from_raw_fd(stream.into_raw_fd())
474            }
475            #[cfg(windows)]
476            {
477                net::TcpStream::from_raw_socket(stream.into_raw_socket())
478            }
479        }
480    }
481}