Skip to main content

mill_io/
reactor.rs

1use crate::{
2    error::Result,
3    poll::PollHandle,
4    thread_pool::{ComputePoolMetrics, ComputeThreadPool, TaskPriority, ThreadPool},
5};
6use mio::{event::Event, Events};
7use std::{
8    sync::{
9        atomic::{AtomicBool, Ordering},
10        Arc, RwLock,
11    },
12    time::Duration,
13};
14
/// Capacity of the events buffer used by `Reactor::default` (1024 events per poll).
pub const DEFAULT_EVENTS_CAPACITY: usize = 1 << 10;
/// Poll timeout, in milliseconds, used by `Reactor::default`.
pub const DEFAULT_POLL_TIMEOUT_MS: u64 = 150;
17
/// Behavioural options for a [`Reactor`].
///
/// `ReactorOptions::default()` yields `direct_dispatch: false`, which is the same setting
/// `Reactor::new` and `Reactor::default` use.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ReactorOptions {
    /// When `true`, matching handlers run synchronously on the thread that calls
    /// `Reactor::dispatch_event` (the `run` loop thread during normal operation).
    /// When `false`, each event is handed to the reactor's I/O thread pool.
    pub direct_dispatch: bool,
}
21
22pub struct Reactor {
23    pub(crate) poll_handle: PollHandle,
24    events: Arc<RwLock<Events>>,
25    pool: ThreadPool,
26    compute_pool: ComputeThreadPool,
27    running: AtomicBool,
28    poll_timeout_ms: u64,
29    options: ReactorOptions,
30}
31
32impl Default for Reactor {
33    fn default() -> Self {
34        Self {
35            poll_handle: PollHandle::new().unwrap(),
36            events: Arc::new(RwLock::new(Events::with_capacity(DEFAULT_EVENTS_CAPACITY))),
37            pool: ThreadPool::default(),
38            compute_pool: ComputeThreadPool::new(
39                std::thread::available_parallelism()
40                    .map(|n| n.get())
41                    .unwrap_or(4),
42            ),
43            running: AtomicBool::new(false),
44            poll_timeout_ms: DEFAULT_POLL_TIMEOUT_MS,
45            options: ReactorOptions {
46                direct_dispatch: false,
47            },
48        }
49    }
50}
51
52impl Reactor {
53    pub fn new(pool_size: usize, events_capacity: usize, poll_timeout_ms: u64) -> Result<Self> {
54        Ok(Self {
55            poll_handle: PollHandle::new()?,
56            events: Arc::new(RwLock::new(Events::with_capacity(events_capacity))),
57            pool: ThreadPool::new(pool_size),
58            compute_pool: ComputeThreadPool::default(),
59            running: AtomicBool::new(false),
60            poll_timeout_ms,
61            options: ReactorOptions {
62                direct_dispatch: false,
63            },
64        })
65    }
66
67    pub fn new_with_options(
68        pool_size: usize,
69        events_capacity: usize,
70        poll_timeout_ms: u64,
71        options: ReactorOptions,
72    ) -> Result<Self> {
73        Ok(Self {
74            poll_handle: PollHandle::new()?,
75            events: Arc::new(RwLock::new(Events::with_capacity(events_capacity))),
76            pool: ThreadPool::new(pool_size),
77            compute_pool: ComputeThreadPool::default(),
78            running: AtomicBool::new(false),
79            poll_timeout_ms,
80            options,
81        })
82    }
83
84    pub fn run(&self) -> Result<()> {
85        self.running.store(true, Ordering::SeqCst);
86
87        while self.running.load(Ordering::SeqCst) {
88            let _ = self.poll_handle.poll(
89                &mut self.events.write().unwrap(),
90                Some(Duration::from_millis(self.poll_timeout_ms)),
91            )?;
92
93            for event in self.events.read().unwrap().iter() {
94                self.dispatch_event(event.clone())?;
95            }
96        }
97        Ok(())
98    }
99
100    pub fn get_shutdown_handle(&self) -> ShutdownHandle<'_> {
101        ShutdownHandle {
102            running: &self.running,
103            poll_handle: &self.poll_handle,
104        }
105    }
106
107    pub fn dispatch_event(&self, event: Event) -> Result<()> {
108        let token = event.token();
109        let is_readable = event.is_readable();
110        let is_writable = event.is_writable();
111
112        let registry = self.poll_handle.get_registery();
113
114        if self.options.direct_dispatch {
115            // FAST PATH: Direct execution on reactor thread
116            if let Some(entry) = registry.get(&token) {
117                let interest = entry.1.interest;
118                let handler = entry.1.handler.as_ref();
119                if (interest.is_readable() && is_readable)
120                    || (interest.is_writable() && is_writable)
121                {
122                    handler.handle_event(&event);
123                }
124            }
125            Ok(())
126        } else {
127            // SLOW PATH: Dispatch to thread pool
128            self.pool.exec(move || {
129                let entry = registry.get(&token);
130                if let Some(entry) = entry {
131                    let interest = entry.1.interest;
132                    let handler = entry.1.handler.as_ref();
133                    if (interest.is_readable() && is_readable)
134                        || (interest.is_writable() && is_writable)
135                    {
136                        handler.handle_event(&event);
137                    }
138                }
139            })
140        }
141    }
142
143    pub fn spawn_compute<F>(&self, task: F, priority: TaskPriority)
144    where
145        F: FnOnce() + Send + 'static,
146    {
147        self.compute_pool.spawn(task, priority);
148    }
149
150    pub fn get_compute_metrics(&self) -> Arc<ComputePoolMetrics> {
151        self.compute_pool.metrics()
152    }
153
154    pub fn get_events(&self) -> Arc<RwLock<Events>> {
155        self.events.clone()
156    }
157}
158
159pub struct ShutdownHandle<'a> {
160    running: &'a AtomicBool,
161    poll_handle: &'a PollHandle,
162}
163
164impl ShutdownHandle<'_> {
165    pub fn shutdown(&self) {
166        self.running.store(false, Ordering::SeqCst);
167        self.poll_handle.wake().unwrap();
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handler::*;
    use mio::{Interest, Token};
    use std::sync::{Arc, Condvar, Mutex};
    use std::time::Duration;

    /// How long a test waits for a dispatched handler to run before failing.
    const HANDLER_WAIT: Duration = Duration::from_millis(500);

    /// Handler that increments `counter` on each event and then notifies `condition`.
    #[derive(Clone)]
    struct TestHandler {
        counter: Arc<Mutex<usize>>,
        condition: Arc<Condvar>,
    }

    impl EventHandler for TestHandler {
        fn handle_event(&self, _event: &Event) {
            let mut count = self.counter.lock().unwrap();
            *count += 1;
            self.condition.notify_one();
        }
    }

    /// Polls `reactor` once (100 ms timeout) and dispatches every returned event.
    ///
    /// Poll and dispatch errors fail the test instead of being silently ignored.
    fn poll_once_and_dispatch(reactor: &Reactor) {
        {
            let mut events = reactor.events.write().unwrap();
            reactor
                .poll_handle
                .poll(&mut events, Some(Duration::from_millis(100)))
                .expect("poll failed");
        }
        let events = reactor.events.read().unwrap();
        for event in events.iter() {
            reactor
                .dispatch_event(event.clone())
                .expect("dispatch failed");
        }
    }

    /// Waits until `counter` is non-zero or `timeout` elapses, then returns its value.
    ///
    /// Uses `wait_timeout_while` so a notification sent before the wait begins is not lost.
    /// This matters because the pool may run the handler before we start waiting.
    fn wait_for_count(counter: &Mutex<usize>, condition: &Condvar, timeout: Duration) -> usize {
        let guard = counter.lock().unwrap();
        let (guard, _timeout) = condition
            .wait_timeout_while(guard, timeout, |count| *count == 0)
            .unwrap();
        *guard
    }

    #[cfg(unix)]
    #[test]
    fn test_reactor_start_stop() {
        let reactor = Arc::new(Reactor::default());
        let shutdown_handle = reactor.get_shutdown_handle();

        let reactor_clone = Arc::clone(&reactor);
        let handle = std::thread::spawn(move || {
            reactor_clone.run().unwrap();
        });

        std::thread::sleep(Duration::from_millis(100));

        shutdown_handle.shutdown();

        handle.join().unwrap();
    }

    #[cfg(unix)]
    #[test]
    fn test_with_pipe() -> std::io::Result<()> {
        use mio::net::UnixStream;

        let reactor =
            Arc::new(Reactor::new(2, DEFAULT_EVENTS_CAPACITY, DEFAULT_POLL_TIMEOUT_MS).unwrap());
        let counter = Arc::new(Mutex::new(0));
        let condition = Arc::new(Condvar::new());

        let (mut stream1, mut stream2) = UnixStream::pair()?;

        let handler = TestHandler {
            counter: Arc::clone(&counter),
            condition: Arc::clone(&condition),
        };

        reactor
            .poll_handle
            .register(&mut stream1, Token(1), Interest::READABLE, handler)
            .unwrap();

        let reactor_clone = Arc::clone(&reactor);
        let poller = std::thread::spawn(move || poll_once_and_dispatch(&reactor_clone));

        std::io::Write::write_all(&mut stream2, b"test data")?;

        poller.join().unwrap();

        assert_eq!(wait_for_count(&counter, &condition, HANDLER_WAIT), 1);
        Ok(())
    }

    #[test]
    fn test_with_tcp() -> std::io::Result<()> {
        use mio::net::{TcpListener, TcpStream};
        use std::net::SocketAddr;

        let reactor =
            Arc::new(Reactor::new(2, DEFAULT_EVENTS_CAPACITY, DEFAULT_POLL_TIMEOUT_MS).unwrap());
        let counter = Arc::new(Mutex::new(0));
        let condition = Arc::new(Condvar::new());

        // Bind to an ephemeral localhost port.
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let mut listener = TcpListener::bind(addr)?;
        let listener_addr = listener.local_addr()?;

        let handler = TestHandler {
            counter: Arc::clone(&counter),
            condition: Arc::clone(&condition),
        };

        reactor
            .poll_handle
            .register(&mut listener, Token(1), Interest::READABLE, handler)
            .unwrap();

        let reactor_clone = Arc::clone(&reactor);
        let poller = std::thread::spawn(move || poll_once_and_dispatch(&reactor_clone));

        // Connecting makes the listener readable. Keep the stream alive until polling is done.
        let _stream = TcpStream::connect(listener_addr)?;

        poller.join().unwrap();

        assert_eq!(wait_for_count(&counter, &condition, HANDLER_WAIT), 1);
        Ok(())
    }
}