fd_queue/
mio.rs

1// Copyright 2020 Steven Bosnick
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE-2.0 or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms
8
9//! An implementation of `EnqueueFd` and `DequeueFd` that is integrated with mio.
10
11use crate::{DequeueFd, EnqueueFd, QueueFullError};
12
13use std::convert::{TryFrom, TryInto};
14use std::io::{self, prelude::*, IoSlice, IoSliceMut};
15use std::net::Shutdown;
16use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
17use std::os::unix::net::{SocketAddr, UnixListener as StdUnixListner, UnixStream as StdUnixStream};
18use std::path::Path;
19
20use mio::{event::Evented, unix::EventedFd, Poll, PollOpt, Ready, Token};
21
22/// A non-blocking Unix stream socket with support for passing [`RawFd`][RawFd].
23///
24/// [RawFd]: https://doc.rust-lang.org/stable/std/os/unix/io/type.RawFd.html
25#[derive(Debug)]
26pub struct UnixStream {
27    inner: crate::UnixStream,
28}
29
30/// A non-blocking Unix domain socket server with support for passing [`RawFd`][RawFd].
31///
32/// [RawFd]: https://doc.rust-lang.org/stable/std/os/unix/io/type.RawFd.html
33#[derive(Debug)]
34pub struct UnixListener {
35    inner: crate::UnixListener,
36}
37
38// === impl UnixStream ===
39impl UnixStream {
40    /// Connects to the socket named by `path`.
41    ///
42    /// Note that this is synchronous.
43    pub fn connect(path: impl AsRef<Path>) -> io::Result<UnixStream> {
44        StdUnixStream::connect(path)?.try_into()
45    }
46
47    /// Creates an unnamed pair of connected sockets.
48    pub fn pair() -> io::Result<(UnixStream, UnixStream)> {
49        let (sock1, sock2) = StdUnixStream::pair()?;
50
51        Ok((sock1.try_into()?, sock2.try_into()?))
52    }
53
54    /// Returns the socket address of the local half of this connection.
55    pub fn local_addr(&self) -> io::Result<SocketAddr> {
56        self.inner.local_addr()
57    }
58
59    /// Returns the socket address of the remote half of this connections.
60    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
61        self.inner.peer_addr()
62    }
63
64    /// Returns the value of the `SO_ERROR` option.
65    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
66        self.inner.take_error()
67    }
68
69    /// Shuts down the read, write, or both halves of the connection.
70    ///
71    /// This function will cause all pending and future I/O calls on the specified portions to
72    /// immediately return with an appropriate value (see the documentation of `Shutdown`).
73    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
74        self.inner.shutdown(how)
75    }
76}
77
78impl EnqueueFd for UnixStream {
79    fn enqueue(&mut self, fd: &impl AsRawFd) -> Result<(), QueueFullError> {
80        self.inner.enqueue(fd)
81    }
82}
83
84impl DequeueFd for UnixStream {
85    fn dequeue(&mut self) -> Option<RawFd> {
86        self.inner.dequeue()
87    }
88}
89
90impl Read for UnixStream {
91    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
92        self.inner.read(buf)
93    }
94
95    fn read_vectored(&mut self, bufs: &mut [IoSliceMut]) -> io::Result<usize> {
96        self.inner.read_vectored(bufs)
97    }
98}
99
100impl Write for UnixStream {
101    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
102        self.inner.write(buf)
103    }
104
105    fn write_vectored(&mut self, bufs: &[IoSlice]) -> io::Result<usize> {
106        self.inner.write_vectored(bufs)
107    }
108
109    fn flush(&mut self) -> io::Result<()> {
110        self.inner.flush()
111    }
112}
113
114impl Evented for UnixStream {
115    fn register(
116        &self,
117        registry: &Poll,
118        token: Token,
119        interests: Ready,
120        opts: PollOpt,
121    ) -> io::Result<()> {
122        EventedFd(&self.as_raw_fd()).register(registry, token, interests, opts)
123    }
124
125    fn reregister(
126        &self,
127        registry: &Poll,
128        token: Token,
129        interests: Ready,
130        opts: PollOpt,
131    ) -> io::Result<()> {
132        EventedFd(&self.as_raw_fd()).reregister(registry, token, interests, opts)
133    }
134
135    fn deregister(&self, registry: &Poll) -> io::Result<()> {
136        EventedFd(&self.as_raw_fd()).deregister(registry)
137    }
138}
139
140impl AsRawFd for UnixStream {
141    fn as_raw_fd(&self) -> RawFd {
142        self.inner.as_raw_fd()
143    }
144}
145
146/// Create a `UnixStream` from a `RawFd`.
147///
148/// This does not change the `RawFd` into non-blocking mode. It assumes that any such
149/// required change has already been done.
150impl FromRawFd for UnixStream {
151    unsafe fn from_raw_fd(fd: RawFd) -> Self {
152        let inner = StdUnixStream::from_raw_fd(fd);
153        UnixStream {
154            inner: inner.into(),
155        }
156    }
157}
158
159impl IntoRawFd for UnixStream {
160    fn into_raw_fd(self) -> RawFd {
161        self.inner.into_raw_fd()
162    }
163}
164
165impl TryFrom<StdUnixStream> for UnixStream {
166    type Error = io::Error;
167
168    fn try_from(inner: StdUnixStream) -> io::Result<UnixStream> {
169        inner.set_nonblocking(true)?;
170
171        Ok(UnixStream {
172            inner: inner.into(),
173        })
174    }
175}
176
177// === impl UnixListener ===
178
179impl UnixListener {
180    /// Creates a new `UnixListener` bound to the specific path.
181    ///
182    /// The listener will be set to non-blocking mode.
183    pub fn bind(path: impl AsRef<Path>) -> io::Result<UnixListener> {
184        StdUnixListner::bind(path)?.try_into()
185    }
186
187    /// Accepts a new incoming connection to this listener.
188    ///
189    /// The returned stream will be set to non-blocking mode.
190    pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> {
191        self.inner.accept().and_then(|(stream, addr)| {
192            stream.set_nonblocking(true)?;
193            Ok((UnixStream { inner: stream }, addr))
194        })
195    }
196
197    /// Returns the local socket address for this listener.
198    pub fn local_addr(&self) -> io::Result<SocketAddr> {
199        self.inner.local_addr()
200    }
201
202    /// Returns the value of the `SO_ERROR` option.
203    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
204        self.inner.take_error()
205    }
206}
207
208impl AsRawFd for UnixListener {
209    fn as_raw_fd(&self) -> RawFd {
210        self.inner.as_raw_fd()
211    }
212}
213
214/// Create a `UnixListener` from a `RawFd`.
215///
216/// This does not change the `RawFd` into non-blocking mode. It assumes that any such
217/// required change has already been done.
218impl FromRawFd for UnixListener {
219    unsafe fn from_raw_fd(fd: RawFd) -> Self {
220        let inner = StdUnixListner::from_raw_fd(fd);
221        UnixListener {
222            inner: inner.into(),
223        }
224    }
225}
226
227impl IntoRawFd for UnixListener {
228    fn into_raw_fd(self) -> RawFd {
229        self.inner.into_raw_fd()
230    }
231}
232
233impl Evented for UnixListener {
234    fn register(
235        &self,
236        registry: &Poll,
237        token: Token,
238        interests: Ready,
239        opts: PollOpt,
240    ) -> io::Result<()> {
241        EventedFd(&self.as_raw_fd()).register(registry, token, interests, opts)
242    }
243
244    fn reregister(
245        &self,
246        registry: &Poll,
247        token: Token,
248        interests: Ready,
249        opts: PollOpt,
250    ) -> io::Result<()> {
251        EventedFd(&self.as_raw_fd()).reregister(registry, token, interests, opts)
252    }
253
254    fn deregister(&self, registry: &Poll) -> io::Result<()> {
255        EventedFd(&self.as_raw_fd()).deregister(registry)
256    }
257}
258
259impl TryFrom<StdUnixListner> for UnixListener {
260    type Error = io::Error;
261
262    fn try_from(inner: StdUnixListner) -> Result<Self, Self::Error> {
263        inner.set_nonblocking(true)?;
264
265        Ok(UnixListener {
266            inner: inner.into(),
267        })
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    use std::io::ErrorKind;
    use std::time::{Duration, Instant};

    use assert_matches::assert_matches;
    use mio::{Events, Poll};

    #[test]
    fn stream_would_block_before_send() {
        let mut buf = [0; 1024];

        let (mut sut, _other) = UnixStream::pair().expect("Unable to create pair.");
        let result = sut.read(buf.as_mut());

        assert_matches!(result, Err(io) => assert_eq!(io.kind(), ErrorKind::WouldBlock));
    }

    #[test]
    fn stream_is_ready_for_read_after_write() {
        let poll = Poll::new().expect("Can't create poll.");
        let mut events = Events::with_capacity(5);

        let (sut, mut other) = UnixStream::pair().expect("Unable to create pair.");
        poll.register(&sut, Token(0), Ready::readable(), PollOpt::edge())
            .expect("Unable to register stream.");
        write_to_stream(&mut other);

        // Bound the wait by wall-clock time rather than by iteration count. Each poll may
        // block for up to one second, so a pure count limit could stall the test for minutes.
        let deadline = Instant::now() + Duration::from_secs(5);
        loop {
            poll.poll(&mut events, Some(Duration::from_secs(1)))
                .expect("Unable to poll.");

            let readable = events
                .iter()
                .any(|event| event.token() == Token(0) && event.readiness().is_readable());
            if readable {
                return;
            }

            assert!(
                Instant::now() < deadline,
                "Stream did not become readable within 5 seconds."
            );
        }
    }

    /// Writes a short message to `stream`, retrying (up to 500 attempts) while the
    /// non-blocking socket reports `WouldBlock`.
    fn write_to_stream(stream: &mut UnixStream) {
        for _ in 0..500 {
            match stream.write(b"hello".as_ref()) {
                Ok(_) => return,
                Err(ref e) if e.kind() == ErrorKind::WouldBlock => {}
                Err(e) => panic!("Unable to write to stream: {}", e),
            }
        }
        panic!("Unable to write to stream after 500 tries");
    }
}