// mill_io/reactor.rs

1use crate::{
2    error::Result,
3    poll::PollHandle,
4    thread_pool::{ComputePoolMetrics, ComputeThreadPool, TaskPriority, ThreadPool},
5};
6use mio::{event::Event, Events};
7use std::{
8    sync::{
9        atomic::{AtomicBool, Ordering},
10        Arc, RwLock,
11    },
12    time::Duration,
13};
14
/// Default capacity of the `Events` buffer a reactor polls into (used by `Reactor::default`).
pub const DEFAULT_EVENTS_CAPACITY: usize = 1_024;
/// Default per-poll timeout in milliseconds (used by `Reactor::default`).
pub const DEFAULT_POLL_TIMEOUT_MS: u64 = 150;
17
18pub struct Reactor {
19    pub(crate) poll_handle: PollHandle,
20    events: Arc<RwLock<Events>>,
21    pool: ThreadPool,
22    compute_pool: ComputeThreadPool,
23    running: AtomicBool,
24    poll_timeout_ms: u64,
25}
26
27impl Default for Reactor {
28    fn default() -> Self {
29        Self {
30            poll_handle: PollHandle::new().unwrap(),
31            events: Arc::new(RwLock::new(Events::with_capacity(DEFAULT_EVENTS_CAPACITY))),
32            pool: ThreadPool::default(),
33            compute_pool: ComputeThreadPool::new(
34                std::thread::available_parallelism()
35                    .map(|n| n.get())
36                    .unwrap_or(4),
37            ),
38            running: AtomicBool::new(false),
39            poll_timeout_ms: DEFAULT_POLL_TIMEOUT_MS,
40        }
41    }
42}
43
44impl Reactor {
45    pub fn new(pool_size: usize, events_capacity: usize, poll_timeout_ms: u64) -> Result<Self> {
46        Ok(Self {
47            poll_handle: PollHandle::new()?,
48            events: Arc::new(RwLock::new(Events::with_capacity(events_capacity))),
49            pool: ThreadPool::new(pool_size),
50            compute_pool: ComputeThreadPool::default(),
51            running: AtomicBool::new(false),
52            poll_timeout_ms,
53        })
54    }
55
56    pub fn run(&self) -> Result<()> {
57        self.running.store(true, Ordering::SeqCst);
58
59        while self.running.load(Ordering::SeqCst) {
60            let _ = self.poll_handle.poll(
61                &mut self.events.write().unwrap(),
62                Some(Duration::from_millis(self.poll_timeout_ms)),
63            )?;
64
65            for event in self.events.read().unwrap().iter() {
66                self.dispatch_event(event.clone())?;
67            }
68        }
69        Ok(())
70    }
71
72    pub fn get_shutdown_handle(&self) -> ShutdownHandle<'_> {
73        ShutdownHandle {
74            running: &self.running,
75            poll_handle: &self.poll_handle,
76        }
77    }
78
79    pub fn dispatch_event(&self, event: Event) -> Result<()> {
80        let token = event.token();
81        let is_readable = event.is_readable();
82        let is_writable = event.is_writable();
83
84        let registry = self.poll_handle.get_registery();
85
86        self.pool.exec(move || {
87            let entry = registry.get(&token);
88            if let Some(entry) = entry {
89                let interest = entry.1.interest;
90                let handler = entry.1.handler.as_ref();
91                if (interest.is_readable() && is_readable)
92                    || (interest.is_writable() && is_writable)
93                {
94                    handler.handle_event(&event);
95                }
96            }
97        })
98    }
99
100    pub fn spawn_compute<F>(&self, task: F, priority: TaskPriority)
101    where
102        F: FnOnce() + Send + 'static,
103    {
104        self.compute_pool.spawn(task, priority);
105    }
106
107    pub fn get_compute_metrics(&self) -> Arc<ComputePoolMetrics> {
108        self.compute_pool.metrics()
109    }
110
111    pub fn get_events(&self) -> Arc<RwLock<Events>> {
112        self.events.clone()
113    }
114}
115
116pub struct ShutdownHandle<'a> {
117    running: &'a AtomicBool,
118    poll_handle: &'a PollHandle,
119}
120
121impl ShutdownHandle<'_> {
122    pub fn shutdown(&self) {
123        self.running.store(false, Ordering::SeqCst);
124        self.poll_handle.wake().unwrap();
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handler::*;
    use mio::{Interest, Token};
    use std::sync::{Arc, Condvar, Mutex};
    use std::time::Duration;

    /// Timeout of the single poll performed by the I/O tests.
    const POLL_TIMEOUT: Duration = Duration::from_millis(100);

    /// Upper bound on how long a test waits for the pool to run the dispatched handler.
    const HANDLER_TIMEOUT: Duration = Duration::from_millis(500);

    /// Handler that counts its invocations and notifies `condition` after each one.
    #[derive(Clone)]
    struct TestHandler {
        counter: Arc<Mutex<usize>>,
        condition: Arc<Condvar>,
    }

    impl TestHandler {
        fn new() -> Self {
            Self {
                counter: Arc::new(Mutex::new(0)),
                condition: Arc::new(Condvar::new()),
            }
        }

        /// Blocks until the handler has run at least once or `timeout` elapses, then returns
        /// the invocation count.
        ///
        /// `wait_timeout_while` checks the predicate under the lock before sleeping and after
        /// every wake-up. Handlers run on the reactor's pool and may finish before this is
        /// called, so a bare `wait_timeout` would miss that notification and time out.
        fn wait_for_calls(&self, timeout: Duration) -> usize {
            let calls = self.counter.lock().unwrap();
            let (calls, _) = self
                .condition
                .wait_timeout_while(calls, timeout, |calls| *calls == 0)
                .unwrap();
            *calls
        }
    }

    impl EventHandler for TestHandler {
        fn handle_event(&self, _event: &Event) {
            let mut count = self.counter.lock().unwrap();
            *count += 1;
            self.condition.notify_one();
        }
    }

    /// Runs exactly one poll/dispatch cycle on `reactor`. It mirrors one iteration of
    /// `Reactor::run` without touching the `running` flag.
    ///
    /// Poll or dispatch failures fail the test instead of being silently ignored.
    fn poll_once(reactor: &Reactor) {
        {
            let mut events = reactor.events.write().unwrap();
            let _ = reactor
                .poll_handle
                .poll(&mut events, Some(POLL_TIMEOUT))
                .expect("poll failed");
        }

        let events = reactor.events.read().unwrap();
        for event in events.iter() {
            reactor
                .dispatch_event(event.clone())
                .expect("failed to dispatch event");
        }
    }

    #[cfg(unix)]
    #[test]
    fn test_reactor_start_stop() {
        let reactor = Arc::new(Reactor::default());
        let shutdown_handle = reactor.get_shutdown_handle();

        let reactor_clone = Arc::clone(&reactor);
        let handle = std::thread::spawn(move || {
            reactor_clone.run().unwrap();
        });

        std::thread::sleep(Duration::from_millis(100));

        shutdown_handle.shutdown();

        handle.join().unwrap();
    }

    #[cfg(unix)]
    #[test]
    fn test_with_pipe() -> std::io::Result<()> {
        use mio::net::UnixStream;
        use std::io::Write;

        let reactor =
            Arc::new(Reactor::new(2, DEFAULT_EVENTS_CAPACITY, DEFAULT_POLL_TIMEOUT_MS).unwrap());
        let handler = TestHandler::new();

        // A connected socket pair stands in for a pipe: writing to one end makes the other
        // end readable.
        let (mut reader, mut writer) = UnixStream::pair()?;

        reactor
            .poll_handle
            .register(&mut reader, Token(1), Interest::READABLE, handler.clone())
            .unwrap();

        let poller = {
            let reactor = Arc::clone(&reactor);
            std::thread::spawn(move || poll_once(&reactor))
        };

        writer.write_all(b"test data")?;
        poller.join().unwrap();

        assert_eq!(handler.wait_for_calls(HANDLER_TIMEOUT), 1);
        Ok(())
    }

    #[test]
    fn test_with_tcp() -> std::io::Result<()> {
        use mio::net::{TcpListener, TcpStream};
        use std::net::SocketAddr;

        let reactor =
            Arc::new(Reactor::new(2, DEFAULT_EVENTS_CAPACITY, DEFAULT_POLL_TIMEOUT_MS).unwrap());
        let handler = TestHandler::new();

        // Port 0 lets the OS choose a free local port.
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let mut listener = TcpListener::bind(addr)?;
        let listener_addr = listener.local_addr()?;

        reactor
            .poll_handle
            .register(&mut listener, Token(1), Interest::READABLE, handler.clone())
            .unwrap();

        let poller = {
            let reactor = Arc::clone(&reactor);
            std::thread::spawn(move || poll_once(&reactor))
        };

        // An incoming connection makes the listener readable.
        let _client = TcpStream::connect(listener_addr)?;
        poller.join().unwrap();

        assert_eq!(handler.wait_for_calls(HANDLER_TIMEOUT), 1);
        Ok(())
    }
}