1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
extern crate mio;
use std::io;

mod traits;

/// A listener whose accept loop can be interrupted through a [`StopHandle`].
///
/// Built with `Listener::from_listener`; connections are obtained by
/// iterating over `&mut Listener<T>`.
// The bound lives on the struct (not just the impls) because the field type
// `T::Listener` is only nameable when `T: traits::IntoListener`.
#[derive(Debug)]
pub struct Listener<T>
where
    T: traits::IntoListener,
{
    /// Listener produced by `T::into_listener`, registered with `poll` under `LISTENER`.
    listener: T::Listener,
    /// Poll instance that both `listener` and `stop_registration` are registered with.
    poll: mio::Poll,
    /// Registration paired with the `StopHandle`'s readiness setter, registered under `BREAK`.
    stop_registration: mio::Registration,
    /// Events from the most recent `poll` call; only overwritten by the next poll.
    events: mio::Events,
}

/// Poll token of the stop registration; a `BREAK` event in the latest poll
/// batch makes the iterator return `None`.
const BREAK: mio::Token = mio::Token(0);
/// Poll token of the wrapped listener. The iterator never matches on it
/// explicitly: any wakeup without `BREAK` simply leads to another `accept`.
const LISTENER: mio::Token = mio::Token(1);

impl<T> Listener<T>
where
    T: traits::IntoListener,
{
    /// Wraps `listener` so its accept loop can be interrupted via a [`StopHandle`].
    ///
    /// The listener is converted with `into_listener`, and a dedicated `mio::Poll`
    /// is created. Two sources are registered with it for edge-triggered
    /// readability: the converted listener under the `LISTENER` token, and an
    /// internal stop registration under the `BREAK` token.
    ///
    /// Returns the stop handle together with the wrapped listener. Accept
    /// connections by iterating over `&mut listener`.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while converting the listener, creating
    /// the poll instance, or registering either source with it.
    pub fn from_listener(listener: T) -> io::Result<(StopHandle, Self)> {
        let listener = listener.into_listener()?;
        let poll = mio::Poll::new()?;
        // `readiness` is the setter half; it goes into the returned StopHandle.
        let (stop_registration, readiness) = mio::Registration::new2();
        poll.register(
            &stop_registration,
            BREAK,
            mio::Ready::readable(),
            mio::PollOpt::edge(),
        )?;
        // Edge-triggered: the consumer must drain `accept` until it reports
        // `WouldBlock` before polling again, or it may miss readiness.
        poll.register(
            &listener,
            LISTENER,
            mio::Ready::readable(),
            mio::PollOpt::edge(),
        )?;
        // Sized to the two tokens registered above.
        let events = mio::Events::with_capacity(2);
        let stoppable_listener = Listener {
            listener,
            poll,
            stop_registration,
            events,
        };
        Ok((StopHandle(readiness), stoppable_listener))
    }
}

/// Yields accepted connections until the paired `StopHandle` requests a stop.
///
/// Implemented on `&mut Listener<T>`, so a listener can be driven with
/// `(&mut listener).next()` or `for conn in &mut listener`.
impl<'a, T> Iterator for &'a mut Listener<T>
where
    T: traits::IntoListener,
{
    /// Each item is the outcome of one `accept` call on the wrapped listener.
    type Item = io::Result<<T::Listener as traits::Listener>::Stream>;

    /// Returns the next accepted connection, `Some(Err(_))` on an accept or
    /// poll error, or `None` once a stop has been observed.
    ///
    /// While nothing is pending and no stop has been seen, this blocks in
    /// `mio::Poll::poll` with no timeout. Once a poll batch contains the
    /// `BREAK` token, `None` is returned without accepting further connections.
    fn next(&mut self) -> Option<Self::Item> {
        // Brings the `accept` trait method into scope (shadows the struct name here).
        use traits::Listener;

        // `events` is only refreshed by `poll`. No poll happens after a `BREAK`
        // event is seen, so every later call also returns `None`.
        while !self.events.iter().any(|ev| ev.token() == BREAK) {
            let accepted = self.listener.accept();
            match accepted {
                // Nothing queued right now: fall through and wait for readiness.
                Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {}
                // A new connection or a genuine error: hand it to the caller.
                outcome => return Some(outcome),
            }
            // Sleep until the listener or the stop registration becomes ready.
            if let Err(err) = self.poll.poll(&mut self.events, None) {
                return Some(Err(err));
            }
        }
        None
    }
}

/// Handle used to stop a `Listener` created by `Listener::from_listener`.
///
/// Wraps the `mio::SetReadiness` half of the listener's internal stop
/// registration. It derives `Clone`, so several owners may hold it.
#[derive(Debug, Clone)]
pub struct StopHandle(mio::SetReadiness);

impl StopHandle {
    /// Requests that the paired listener stop yielding connections.
    ///
    /// Marks the stop registration readable. This wakes the listener's poll
    /// with a `BREAK` event, after which its iterator returns `None`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `mio::SetReadiness::set_readiness`.
    pub fn stop_listening(&self) -> io::Result<()> {
        let StopHandle(ref set_readiness) = *self;
        let readable = mio::Ready::readable();
        set_readiness.set_readiness(readable)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Serves ten TCP clients with "hello world", then checks that
    /// `stop_listening` makes the accept loop end cleanly.
    #[test]
    fn tcp() {
        use std::io::{Read, Write};
        let listener = std::net::TcpListener::bind("[::1]:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let (stop_handle, mut listener) = Listener::from_listener(listener).unwrap();
        let listen_thread = std::thread::spawn(move || loop {
            match (&mut listener).next() {
                Some(Ok((mut socket, _))) => write!(socket, "hello world").unwrap(),
                Some(Err(e)) => panic!("{}", e),
                None => return,
            }
        });
        for _ in 0..10 {
            let mut stream = std::net::TcpStream::connect(addr).unwrap();
            let mut response = String::new();
            stream.read_to_string(&mut response).unwrap();
            assert_eq!(response, "hello world");
        }
        stop_handle.stop_listening().unwrap();
        // Joining surfaces any panic from the listener thread. Its result is
        // `()`, so there is nothing to compare (clippy::unit_cmp).
        listen_thread.join().expect("listener thread panicked");
    }

    /// Same scenario as `tcp`, over a Unix domain socket bound to a temp path.
    #[cfg(unix)]
    #[test]
    fn uds() {
        use std::io::{Read, Write};
        let mut attempt = 0;
        let listener = loop {
            attempt += 1;
            if attempt > 3 {
                panic!("unable to bind unix socket to temp file");
            }

            // The NamedTempFile is a temporary dropped at the end of this
            // statement, deleting the file, so only an unused path remains.
            // Another process can still grab it before `bind`, hence the retries.
            let tmp_filename = tempfile::NamedTempFile::new().unwrap().path().to_path_buf();
            if let Ok(listener) = std::os::unix::net::UnixListener::bind(tmp_filename) {
                break listener;
            }
        };
        let socket_path = listener
            .local_addr()
            .unwrap()
            .as_pathname()
            .unwrap()
            .to_path_buf();
        let (stop_handle, mut listener) = Listener::from_listener(listener).unwrap();
        let listen_thread = std::thread::spawn(move || loop {
            match (&mut listener).next() {
                Some(Ok((mut socket, _))) => write!(socket, "hello world").unwrap(),
                Some(Err(e)) => panic!("{}", e),
                None => return,
            }
        });
        for _ in 0..10 {
            let mut stream = std::os::unix::net::UnixStream::connect(&socket_path).unwrap();
            let mut response = String::new();
            stream.read_to_string(&mut response).unwrap();
            assert_eq!(response, "hello world");
        }
        stop_handle.stop_listening().unwrap();
        listen_thread.join().expect("listener thread panicked");
        // `bind` created a socket file at this path. Remove it so test runs do
        // not leave files behind; best-effort, since cleanup failure is not a test failure.
        let _ = std::fs::remove_file(&socket_path);
    }
}