1use crate::{
2 error::Result,
3 poll::PollHandle,
4 thread_pool::{ComputePoolMetrics, ComputeThreadPool, TaskPriority, ThreadPool},
5};
6use mio::{event::Event, Events};
7use std::{
8 sync::{
9 atomic::{AtomicBool, Ordering},
10 Arc, RwLock,
11 },
12 time::Duration,
13};
14
/// Default number of events an `Events` buffer can hold for a single poll
/// (used by `Reactor::default`).
pub const DEFAULT_EVENTS_CAPACITY: usize = 1 << 10;
/// Default poll timeout in milliseconds (used by `Reactor::default`).
pub const DEFAULT_POLL_TIMEOUT_MS: u64 = 150;
17
/// Tuning knobs for a `Reactor`.
///
/// `ReactorOptions::default()` yields `direct_dispatch: false`, the same value
/// that `Reactor::new` and `Reactor::default` use.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ReactorOptions {
    /// When `true`, `Reactor::dispatch_event` calls the matching handler on the
    /// calling thread; when `false`, the handler call is submitted to the
    /// reactor's handler thread pool.
    pub direct_dispatch: bool,
}
21
/// Readiness-driven event loop.
///
/// `run` repeatedly polls the sources registered on `poll_handle` and hands
/// each event to the handler registered for its token, either inline or on a
/// thread pool depending on `ReactorOptions::direct_dispatch`. A separate
/// compute pool accepts tasks through `spawn_compute`.
///
/// NOTE: struct fields are dropped in declaration order, so the order below
/// determines teardown order.
pub struct Reactor {
    /// Poll instance plus the token -> handler registry that `dispatch_event`
    /// consults. Crate-visible; the tests register sources through it directly.
    pub(crate) poll_handle: PollHandle,
    /// Reusable buffer filled by every poll in `run`; shared via `get_events`.
    events: Arc<RwLock<Events>>,
    /// Pool that runs handlers when `direct_dispatch` is `false`.
    pool: ThreadPool,
    /// Pool for tasks submitted through `spawn_compute` (presumably CPU-bound
    /// work, judging by the name -- confirm).
    compute_pool: ComputeThreadPool,
    /// Loop flag: set to `true` by `run`, cleared by `ShutdownHandle::shutdown`.
    running: AtomicBool,
    /// Timeout for each poll in `run`, in milliseconds.
    poll_timeout_ms: u64,
    /// Dispatch behaviour (inline vs. thread pool).
    options: ReactorOptions,
}
31
32impl Default for Reactor {
33 fn default() -> Self {
34 Self {
35 poll_handle: PollHandle::new().unwrap(),
36 events: Arc::new(RwLock::new(Events::with_capacity(DEFAULT_EVENTS_CAPACITY))),
37 pool: ThreadPool::default(),
38 compute_pool: ComputeThreadPool::new(
39 std::thread::available_parallelism()
40 .map(|n| n.get())
41 .unwrap_or(4),
42 ),
43 running: AtomicBool::new(false),
44 poll_timeout_ms: DEFAULT_POLL_TIMEOUT_MS,
45 options: ReactorOptions {
46 direct_dispatch: false,
47 },
48 }
49 }
50}
51
52impl Reactor {
53 pub fn new(pool_size: usize, events_capacity: usize, poll_timeout_ms: u64) -> Result<Self> {
54 Ok(Self {
55 poll_handle: PollHandle::new()?,
56 events: Arc::new(RwLock::new(Events::with_capacity(events_capacity))),
57 pool: ThreadPool::new(pool_size),
58 compute_pool: ComputeThreadPool::default(),
59 running: AtomicBool::new(false),
60 poll_timeout_ms,
61 options: ReactorOptions {
62 direct_dispatch: false,
63 },
64 })
65 }
66
67 pub fn new_with_options(
68 pool_size: usize,
69 events_capacity: usize,
70 poll_timeout_ms: u64,
71 options: ReactorOptions,
72 ) -> Result<Self> {
73 Ok(Self {
74 poll_handle: PollHandle::new()?,
75 events: Arc::new(RwLock::new(Events::with_capacity(events_capacity))),
76 pool: ThreadPool::new(pool_size),
77 compute_pool: ComputeThreadPool::default(),
78 running: AtomicBool::new(false),
79 poll_timeout_ms,
80 options,
81 })
82 }
83
84 pub fn run(&self) -> Result<()> {
85 self.running.store(true, Ordering::SeqCst);
86
87 while self.running.load(Ordering::SeqCst) {
88 let _ = self.poll_handle.poll(
89 &mut self.events.write().unwrap(),
90 Some(Duration::from_millis(self.poll_timeout_ms)),
91 )?;
92
93 for event in self.events.read().unwrap().iter() {
94 self.dispatch_event(event.clone())?;
95 }
96 }
97 Ok(())
98 }
99
100 pub fn get_shutdown_handle(&self) -> ShutdownHandle<'_> {
101 ShutdownHandle {
102 running: &self.running,
103 poll_handle: &self.poll_handle,
104 }
105 }
106
107 pub fn dispatch_event(&self, event: Event) -> Result<()> {
108 let token = event.token();
109 let is_readable = event.is_readable();
110 let is_writable = event.is_writable();
111
112 let registry = self.poll_handle.get_registery();
113
114 if self.options.direct_dispatch {
115 if let Some(entry) = registry.get(&token) {
117 let interest = entry.1.interest;
118 let handler = entry.1.handler.as_ref();
119 if (interest.is_readable() && is_readable)
120 || (interest.is_writable() && is_writable)
121 {
122 handler.handle_event(&event);
123 }
124 }
125 Ok(())
126 } else {
127 self.pool.exec(move || {
129 let entry = registry.get(&token);
130 if let Some(entry) = entry {
131 let interest = entry.1.interest;
132 let handler = entry.1.handler.as_ref();
133 if (interest.is_readable() && is_readable)
134 || (interest.is_writable() && is_writable)
135 {
136 handler.handle_event(&event);
137 }
138 }
139 })
140 }
141 }
142
143 pub fn spawn_compute<F>(&self, task: F, priority: TaskPriority)
144 where
145 F: FnOnce() + Send + 'static,
146 {
147 self.compute_pool.spawn(task, priority);
148 }
149
150 pub fn get_compute_metrics(&self) -> Arc<ComputePoolMetrics> {
151 self.compute_pool.metrics()
152 }
153
154 pub fn get_events(&self) -> Arc<RwLock<Events>> {
155 self.events.clone()
156 }
157}
158
/// Borrowed handle used to stop a running `Reactor` loop.
///
/// Obtained from `Reactor::get_shutdown_handle`. It borrows the reactor's
/// fields, so it cannot outlive the reactor.
pub struct ShutdownHandle<'a> {
    /// The reactor's loop flag; clearing it makes `Reactor::run` exit.
    running: &'a AtomicBool,
    /// Used to wake a `run` call blocked in poll so it re-checks `running`.
    poll_handle: &'a PollHandle,
}
163
164impl ShutdownHandle<'_> {
165 pub fn shutdown(&self) {
166 self.running.store(false, Ordering::SeqCst);
167 self.poll_handle.wake().unwrap();
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handler::*;
    use mio::{Interest, Token};
    use std::sync::{Arc, Condvar, Mutex};
    use std::time::Duration;

    /// Upper bound for a single test poll. It can be generous because poll
    /// returns as soon as an event is ready.
    const TEST_POLL_TIMEOUT: Duration = Duration::from_secs(1);

    /// How long to wait for a pool thread to run the handler.
    const HANDLER_WAIT_TIMEOUT: Duration = Duration::from_secs(2);

    /// Handler that counts invocations and signals a condvar after each one.
    #[derive(Clone)]
    struct TestHandler {
        counter: Arc<Mutex<usize>>,
        condition: Arc<Condvar>,
    }

    impl EventHandler for TestHandler {
        fn handle_event(&self, _event: &Event) {
            let mut count = self.counter.lock().unwrap();
            *count += 1;
            self.condition.notify_one();
        }
    }

    /// Performs one reactor-loop iteration by hand: polls once, then
    /// dispatches every returned event. Any poll or dispatch error fails the
    /// test instead of being silently ignored.
    fn poll_and_dispatch_once(reactor: &Reactor, timeout: Duration) {
        {
            let mut events = reactor.events.write().unwrap();
            reactor
                .poll_handle
                .poll(&mut events, Some(timeout))
                .expect("poll failed");
        }
        let events = reactor.events.read().unwrap();
        for event in events.iter() {
            reactor
                .dispatch_event(event.clone())
                .expect("dispatch_event failed");
        }
    }

    /// Waits until `counter` is non-zero or `timeout` elapses, then returns
    /// the count. The predicate form does not miss a notification that fired
    /// before waiting started, and it also tolerates spurious wakeups.
    fn wait_for_handler(counter: &Mutex<usize>, condition: &Condvar, timeout: Duration) -> usize {
        let guard = counter.lock().unwrap();
        let (guard, _timeout) = condition
            .wait_timeout_while(guard, timeout, |count| *count == 0)
            .unwrap();
        *guard
    }

    #[cfg(unix)]
    #[test]
    fn test_reactor_start_stop() {
        let reactor = Arc::new(Reactor::default());
        let shutdown_handle = reactor.get_shutdown_handle();

        let reactor_clone = Arc::clone(&reactor);
        let handle = std::thread::spawn(move || {
            reactor_clone.run().unwrap();
        });

        // Give `run` time to start: a shutdown issued before `run` sets the
        // running flag would be overwritten.
        std::thread::sleep(Duration::from_millis(100));

        shutdown_handle.shutdown();

        handle.join().unwrap();
    }

    #[cfg(unix)]
    #[test]
    fn test_with_pipe() -> std::io::Result<()> {
        use mio::net::UnixStream;

        let reactor = Reactor::new(2, DEFAULT_EVENTS_CAPACITY, DEFAULT_POLL_TIMEOUT_MS).unwrap();
        let counter = Arc::new(Mutex::new(0));
        let condition = Arc::new(Condvar::new());

        let (mut stream1, mut stream2) = UnixStream::pair()?;

        let handler = TestHandler {
            counter: Arc::clone(&counter),
            condition: Arc::clone(&condition),
        };

        reactor
            .poll_handle
            .register(&mut stream1, Token(1), Interest::READABLE, handler)
            .unwrap();

        // Write before polling. `stream1` is already registered, so the
        // resulting readiness is reported by the next poll, and there is no
        // race with a short poll window.
        std::io::Write::write_all(&mut stream2, b"test data")?;

        poll_and_dispatch_once(&reactor, TEST_POLL_TIMEOUT);

        let count = wait_for_handler(&counter, &condition, HANDLER_WAIT_TIMEOUT);
        assert_eq!(count, 1, "handler should run exactly once for the written data");

        Ok(())
    }

    #[test]
    fn test_with_tcp() -> std::io::Result<()> {
        use mio::net::{TcpListener, TcpStream};
        use std::net::SocketAddr;

        let reactor = Reactor::new(2, DEFAULT_EVENTS_CAPACITY, DEFAULT_POLL_TIMEOUT_MS).unwrap();
        let counter = Arc::new(Mutex::new(0));
        let condition = Arc::new(Condvar::new());

        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let mut listener = TcpListener::bind(addr)?;
        let listener_addr = listener.local_addr()?;

        let handler = TestHandler {
            counter: Arc::clone(&counter),
            condition: Arc::clone(&condition),
        };

        reactor
            .poll_handle
            .register(&mut listener, Token(1), Interest::READABLE, handler)
            .unwrap();

        // Connect before polling; the client is kept alive (named binding)
        // until the test ends. If the handshake is still in flight, the poll
        // below blocks until the listener becomes readable.
        let _stream = TcpStream::connect(listener_addr)?;

        poll_and_dispatch_once(&reactor, TEST_POLL_TIMEOUT);

        let count = wait_for_handler(&counter, &condition, HANDLER_WAIT_TIMEOUT);
        assert_eq!(count, 1, "handler should run exactly once for the pending connection");

        Ok(())
    }
}