1use crate::{
2 error::Result,
3 poll::PollHandle,
4 thread_pool::{ComputePoolMetrics, ComputeThreadPool, TaskPriority, ThreadPool},
5};
6use mio::{event::Event, Events};
7use std::{
8 sync::{
9 atomic::{AtomicBool, Ordering},
10 Arc, RwLock,
11 },
12 time::Duration,
13};
14
/// Default capacity of the `Events` buffer that each poll fills
/// (used by `Reactor::default`).
pub const DEFAULT_EVENTS_CAPACITY: usize = 1_024;

/// Default upper bound, in milliseconds, on how long one poll may block
/// (used by `Reactor::default`).
pub const DEFAULT_POLL_TIMEOUT_MS: u64 = 150;
17
/// Readiness-based event loop: polls for I/O events and dispatches each one
/// to its registered handler on a thread pool.
///
/// Construct with [`Reactor::new`] or `Reactor::default()`, drive it with
/// [`Reactor::run`], and stop it through a [`ShutdownHandle`].
pub struct Reactor {
    /// Wraps the poller; also supplies the handler registry (via
    /// `get_registery`) that `dispatch_event` looks tokens up in.
    pub(crate) poll_handle: PollHandle,
    /// Buffer filled by each poll and read back during dispatch; shared behind
    /// a lock so `get_events` can hand it out.
    events: Arc<RwLock<Events>>,
    /// Pool on which registered event handlers are invoked (see `dispatch_event`).
    pool: ThreadPool,
    /// Pool for CPU-bound tasks submitted through `spawn_compute`.
    compute_pool: ComputeThreadPool,
    /// Loop flag: set by `run`, cleared by `ShutdownHandle::shutdown`, and
    /// checked before every poll.
    running: AtomicBool,
    /// Upper bound, in milliseconds, on how long a single poll may block.
    poll_timeout_ms: u64,
}
26
27impl Default for Reactor {
28 fn default() -> Self {
29 Self {
30 poll_handle: PollHandle::new().unwrap(),
31 events: Arc::new(RwLock::new(Events::with_capacity(DEFAULT_EVENTS_CAPACITY))),
32 pool: ThreadPool::default(),
33 compute_pool: ComputeThreadPool::new(
34 std::thread::available_parallelism()
35 .map(|n| n.get())
36 .unwrap_or(4),
37 ),
38 running: AtomicBool::new(false),
39 poll_timeout_ms: DEFAULT_POLL_TIMEOUT_MS,
40 }
41 }
42}
43
44impl Reactor {
45 pub fn new(pool_size: usize, events_capacity: usize, poll_timeout_ms: u64) -> Result<Self> {
46 Ok(Self {
47 poll_handle: PollHandle::new()?,
48 events: Arc::new(RwLock::new(Events::with_capacity(events_capacity))),
49 pool: ThreadPool::new(pool_size),
50 compute_pool: ComputeThreadPool::default(),
51 running: AtomicBool::new(false),
52 poll_timeout_ms,
53 })
54 }
55
56 pub fn run(&self) -> Result<()> {
57 self.running.store(true, Ordering::SeqCst);
58
59 while self.running.load(Ordering::SeqCst) {
60 let _ = self.poll_handle.poll(
61 &mut self.events.write().unwrap(),
62 Some(Duration::from_millis(self.poll_timeout_ms)),
63 )?;
64
65 for event in self.events.read().unwrap().iter() {
66 self.dispatch_event(event.clone())?;
67 }
68 }
69 Ok(())
70 }
71
72 pub fn get_shutdown_handle(&self) -> ShutdownHandle<'_> {
73 ShutdownHandle {
74 running: &self.running,
75 poll_handle: &self.poll_handle,
76 }
77 }
78
79 pub fn dispatch_event(&self, event: Event) -> Result<()> {
80 let token = event.token();
81 let is_readable = event.is_readable();
82 let is_writable = event.is_writable();
83
84 let registry = self.poll_handle.get_registery();
85
86 self.pool.exec(move || {
87 let entry = registry.get(&token);
88 if let Some(entry) = entry {
89 let interest = entry.1.interest;
90 let handler = entry.1.handler.as_ref();
91 if (interest.is_readable() && is_readable)
92 || (interest.is_writable() && is_writable)
93 {
94 handler.handle_event(&event);
95 }
96 }
97 })
98 }
99
100 pub fn spawn_compute<F>(&self, task: F, priority: TaskPriority)
101 where
102 F: FnOnce() + Send + 'static,
103 {
104 self.compute_pool.spawn(task, priority);
105 }
106
107 pub fn get_compute_metrics(&self) -> Arc<ComputePoolMetrics> {
108 self.compute_pool.metrics()
109 }
110
111 pub fn get_events(&self) -> Arc<RwLock<Events>> {
112 self.events.clone()
113 }
114}
115
116pub struct ShutdownHandle<'a> {
117 running: &'a AtomicBool,
118 poll_handle: &'a PollHandle,
119}
120
121impl ShutdownHandle<'_> {
122 pub fn shutdown(&self) {
123 self.running.store(false, Ordering::SeqCst);
124 self.poll_handle.wake().unwrap();
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handler::*;
    use mio::{Interest, Token};
    use std::sync::{Arc, Condvar, Mutex};
    use std::time::Duration;

    /// Timeout for the single poll performed by [`poll_once_and_dispatch`].
    const POLL_WAIT: Duration = Duration::from_millis(100);

    /// How long a test waits for the dispatch pool to run a handler.
    const HANDLER_WAIT: Duration = Duration::from_millis(500);

    /// Handler that counts its invocations and signals a condition variable so
    /// the test thread can wait for dispatch to finish.
    #[derive(Clone)]
    struct TestHandler {
        counter: Arc<Mutex<usize>>,
        condition: Arc<Condvar>,
    }

    impl EventHandler for TestHandler {
        fn handle_event(&self, _event: &Event) {
            let mut count = self.counter.lock().unwrap();
            *count += 1;
            self.condition.notify_one();
        }
    }

    /// Performs one poll on a background thread and dispatches every returned
    /// event. A poll or dispatch failure panics that thread, and the panic is
    /// propagated through `join`.
    fn poll_once_and_dispatch(reactor: &Arc<Reactor>) {
        let reactor = Arc::clone(reactor);
        std::thread::spawn(move || {
            let polled = {
                // Scope the write guard so it is released before the read below.
                let mut events = reactor.events.write().unwrap();
                reactor.poll_handle.poll(&mut events, Some(POLL_WAIT))
            };
            assert!(polled.is_ok(), "poll failed");

            let events = reactor.events.read().unwrap();
            for event in events.iter() {
                assert!(
                    reactor.dispatch_event(event.clone()).is_ok(),
                    "dispatch failed"
                );
            }
        })
        .join()
        .unwrap();
    }

    /// Waits until the handler has run at least once and returns the count.
    ///
    /// Fails the test if the handler does not run within [`HANDLER_WAIT`].
    fn wait_for_handler(counter: &Mutex<usize>, condition: &Condvar) -> usize {
        let guard = counter.lock().unwrap();
        // `wait_timeout_while` re-checks the predicate. A notification sent
        // before we start waiting is therefore not lost, and spurious wakeups
        // are ignored.
        let (guard, timeout) = condition
            .wait_timeout_while(guard, HANDLER_WAIT, |count| *count == 0)
            .unwrap();
        assert!(
            !timeout.timed_out(),
            "handler was not invoked within {:?}",
            HANDLER_WAIT
        );
        *guard
    }

    #[cfg(unix)]
    #[test]
    fn test_reactor_start_stop() {
        let reactor = Arc::new(Reactor::default());
        let shutdown_handle = reactor.get_shutdown_handle();

        let reactor_clone = Arc::clone(&reactor);
        let handle = std::thread::spawn(move || {
            reactor_clone.run().unwrap();
        });

        // Give `run` time to set its running flag: a shutdown that lands before
        // that store would be overwritten.
        std::thread::sleep(Duration::from_millis(100));

        shutdown_handle.shutdown();

        handle.join().unwrap();
    }

    #[cfg(unix)]
    #[test]
    fn test_with_pipe() -> std::io::Result<()> {
        use mio::net::UnixStream;

        let reactor =
            Arc::new(Reactor::new(2, DEFAULT_EVENTS_CAPACITY, DEFAULT_POLL_TIMEOUT_MS).unwrap());
        let counter = Arc::new(Mutex::new(0));
        let condition = Arc::new(Condvar::new());

        let (mut stream1, mut stream2) = UnixStream::pair()?;

        let handler = TestHandler {
            counter: Arc::clone(&counter),
            condition: Arc::clone(&condition),
        };

        reactor
            .poll_handle
            .register(&mut stream1, Token(1), Interest::READABLE, handler)
            .unwrap();

        // Write before polling. `stream1` is already registered, so the
        // readiness is reported by the next poll rather than racing its timeout.
        std::io::Write::write_all(&mut stream2, b"test data")?;

        poll_once_and_dispatch(&reactor);

        assert_eq!(wait_for_handler(&counter, &condition), 1);

        Ok(())
    }

    #[test]
    fn test_with_tcp() -> std::io::Result<()> {
        use mio::net::{TcpListener, TcpStream};
        use std::net::SocketAddr;

        let reactor =
            Arc::new(Reactor::new(2, DEFAULT_EVENTS_CAPACITY, DEFAULT_POLL_TIMEOUT_MS).unwrap());
        let counter = Arc::new(Mutex::new(0));
        let condition = Arc::new(Condvar::new());

        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let mut listener = TcpListener::bind(addr)?;
        let listener_addr = listener.local_addr()?;

        let handler = TestHandler {
            counter: Arc::clone(&counter),
            condition: Arc::clone(&condition),
        };

        reactor
            .poll_handle
            .register(&mut listener, Token(1), Interest::READABLE, handler)
            .unwrap();

        // Connect before polling. The listener is already registered, so the
        // pending connection is reported by the next poll. Keep the stream
        // alive until the end of the test.
        let _stream = TcpStream::connect(listener_addr)?;

        poll_once_and_dispatch(&reactor);

        assert_eq!(wait_for_handler(&counter, &condition), 1);

        Ok(())
    }
}