1use std::{
20 io,
21 ops::Deref,
22 sync::{
23 atomic::{AtomicU64, Ordering},
24 Arc, OnceLock,
25 },
26 task::Waker,
27 time::{Duration, Instant},
28};
29
30use dashmap::DashMap;
31use mio::{
32 event::{self, Source},
33 Interest, Token,
34};
35use rasi_syscall::{CancelablePoll, Handle};
36
/// Bundles an mio event source with a [`Token`].
///
/// Dropping the wrapper deregisters `socket` from the global reactor
/// (see the `Drop` impl for this type).
pub(crate) struct MioSocket<S: Source> {
    /// Token stored alongside the socket.
    // NOTE(review): presumably the token the socket was registered under — confirm at call sites.
    pub(crate) token: Token,
    /// The wrapped mio event source; also the `Deref` target of this type.
    pub(crate) socket: S,
}
44
45impl<S: Source> From<(Token, S)> for MioSocket<S> {
46 fn from(value: (Token, S)) -> Self {
47 Self {
48 token: value.0,
49 socket: value.1,
50 }
51 }
52}
53
54impl<S: Source> Deref for MioSocket<S> {
55 type Target = S;
56 fn deref(&self) -> &Self::Target {
57 &self.socket
58 }
59}
60
61impl<S: Source> Drop for MioSocket<S> {
62 fn drop(&mut self) {
63 if global_reactor().deregister(&mut self.socket).is_err() {}
64 }
65}
66
67pub(crate) fn would_block<T, F>(
72 token: Token,
73 waker: Waker,
74 interests: Interest,
75 mut f: F,
76) -> CancelablePoll<io::Result<T>>
77where
78 F: FnMut() -> io::Result<T>,
79{
80 global_reactor().once(token, interests, waker);
81
82 loop {
83 match f() {
84 Ok(t) => {
85 return {
86 global_reactor().remove_listeners(token, interests);
87
88 CancelablePoll::Ready(Ok(t))
89 }
90 }
91 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
92 return CancelablePoll::Pending(Handle::new(()));
93 }
94 Err(err) if err.kind() == io::ErrorKind::Interrupted => {
95 continue;
96 }
97 Err(err) => {
98 global_reactor().remove_listeners(token, interests);
99 return CancelablePoll::Ready(Err(err));
100 }
101 }
102 }
103}
104
/// Timer wheel that groups timer tokens by the tick in which they expire.
struct Timewheel {
    /// Length of one tick, in microseconds (derived from a `Duration` via `as_micros`).
    tick_interval: u64,
    /// Pending timers: tick index -> tokens that expire in that tick.
    timers: DashMap<u64, boxcar::Vec<Token>>,
    /// Tick index the wheel has advanced to so far.
    ticks: AtomicU64,
    /// Instant the wheel was created; tick indices are measured from here.
    start_instant: Instant,
    /// Incremented once for every timer `new_timer` accepts.
    // NOTE(review): nothing visible in this file ever decrements it — confirm
    // whether it is meant to be a total or a live count.
    timer_count: AtomicU64,
}
117
118impl Timewheel {
119 fn new(tick_interval: Duration) -> Self {
120 Self {
121 tick_interval: tick_interval.as_micros() as u64,
122 ticks: Default::default(),
123 start_instant: Instant::now(),
124 timer_count: Default::default(),
125 timers: Default::default(),
126 }
127 }
128
129 #[allow(unused)]
131 fn timers(&self) -> u64 {
132 self.timer_count.load(Ordering::Relaxed)
133 }
134
135 pub fn new_timer(&self, token: Token, deadline: Instant) -> Option<u64> {
137 let ticks = (deadline - self.start_instant).as_micros() as u64 / self.tick_interval;
138
139 let ticks = ticks as u64;
140
141 if self
142 .ticks
143 .fetch_update(Ordering::Release, Ordering::Acquire, |current| {
144 if current > ticks {
145 None
146 } else {
147 Some(current)
148 }
149 })
150 .is_err()
151 {
152 return None;
153 }
154
155 self.timers.entry(ticks).or_default().push(token);
156
157 if self.ticks.load(Ordering::SeqCst) > ticks {
158 if self.timers.remove(&ticks).is_some() {
159 return None;
160 }
161 }
162
163 if self
164 .ticks
165 .fetch_update(Ordering::Release, Ordering::Acquire, |current| {
166 if current > ticks {
167 if self.timers.remove(&ticks).is_some() {
168 return None;
169 } else {
170 return Some(current);
171 }
172 } else {
173 return Some(current);
174 }
175 })
176 .is_err()
177 {
178 return None;
179 }
180
181 self.timer_count.fetch_add(1, Ordering::SeqCst);
182
183 Some(ticks)
184 }
185
186 pub fn next_tick(&self) -> Option<Vec<Token>> {
188 loop {
189 let current = self.ticks.load(Ordering::Acquire);
190
191 let instant_duration = Instant::now() - self.start_instant;
192
193 let ticks = instant_duration.as_micros() as u64 / self.tick_interval;
194
195 assert!(current <= ticks);
196
197 if current == ticks {
198 return None;
199 }
200
201 if self
202 .ticks
203 .compare_exchange(current, ticks, Ordering::AcqRel, Ordering::Relaxed)
204 .is_ok()
205 {
206 let mut timeout_timers = vec![];
207
208 for i in current..ticks {
209 if let Some((_, queue)) = self.timers.remove(&i) {
210 for t in queue.into_iter() {
211 timeout_timers.push(t);
212 }
213 }
214 }
215
216 return Some(timeout_timers);
217 }
218 }
219 }
220}
221
/// Shared reactor state: the mio registry, the wakers waiting for readiness
/// events, and the timer wheel.
pub struct Reactor {
    /// Registry used to register / deregister event sources.
    mio_registry: mio::Registry,
    /// Wakers waiting for a readable event, keyed by token.
    read_op_wakers: DashMap<Token, Waker>,
    /// Wakers waiting for a writable event, keyed by token.
    /// Timer wakers registered through `deadline` are stored here as well.
    write_op_wakers: DashMap<Token, Waker>,
    /// Timer wheel backing `deadline`.
    timewheel: Timewheel,
}
236
/// Reference-counted handle to a [`Reactor`].
pub type ArcReactor = Arc<Reactor>;
239
240impl Reactor {
241 fn new(tick_interval: Duration) -> io::Result<ArcReactor> {
242 let mio_poll = mio::Poll::new()?;
243
244 let mio_registry = mio_poll.registry().try_clone()?;
245
246 let reactor = Arc::new(Reactor {
247 mio_registry,
248 read_op_wakers: Default::default(),
249 write_op_wakers: Default::default(),
250 timewheel: Timewheel::new(tick_interval),
251 });
252
253 let background = ReactorBackground::new(tick_interval, mio_poll, reactor.clone());
254
255 background.start();
256
257 Ok(reactor)
258 }
259
260 pub fn register<S>(&self, source: &mut S, token: Token, interests: Interest) -> io::Result<()>
262 where
263 S: event::Source + ?Sized,
264 {
265 self.mio_registry.register(source, token, interests)
266 }
267
268 pub fn deregister<S>(&self, source: &mut S) -> io::Result<()>
270 where
271 S: event::Source + ?Sized,
272 {
273 self.mio_registry.deregister(source)
274 }
275
276 pub fn deadline(&self, token: Token, waker: Waker, deadline: Instant) -> Option<u64> {
278 self.write_op_wakers.insert(token, waker);
279
280 if let Some(id) = self.timewheel.new_timer(token, deadline) {
282 Some(id)
283 } else {
284 self.write_op_wakers.remove(&token);
288 None
289 }
290 }
291
292 pub fn once(&self, token: Token, interests: Interest, waker: Waker) {
294 if interests.is_readable() {
295 self.read_op_wakers.insert(token, waker.clone());
296 }
297
298 if interests.is_writable() {
299 self.write_op_wakers.insert(token, waker);
300 }
301 }
302
303 pub fn notify(&self, token: Token, interests: Interest) {
305 if interests.is_readable() {
306 if let Some(waker) = self.read_op_wakers.remove(&token).map(|(_, v)| v) {
307 waker.wake();
308 }
309 }
310
311 if interests.is_writable() {
312 if let Some(waker) = self.write_op_wakers.remove(&token).map(|(_, v)| v) {
313 waker.wake();
314 }
315 }
316 }
317
318 pub fn remove_listeners(&self, token: Token, interests: Interest) {
320 if interests.is_readable() {
321 self.read_op_wakers.remove(&token);
322 }
323
324 if interests.is_writable() {
325 self.write_op_wakers.remove(&token);
326 }
327 }
328}
329
/// State owned by the background thread that drives a [`Reactor`].
struct ReactorBackground {
    /// The mio poller the dispatch loop blocks on.
    mio_poll: mio::Poll,
    /// Reactor whose wakers and timer wheel are serviced by the loop.
    reactor: ArcReactor,
    /// Timeout passed to each poll call.
    tick_interval: Duration,
}
336
337impl ReactorBackground {
338 fn new(tick_interval: Duration, mio_poll: mio::Poll, reactor: ArcReactor) -> Self {
339 Self {
340 mio_poll,
341 reactor,
342 tick_interval,
343 }
344 }
345
346 fn start(mut self) {
348 std::thread::spawn(move || {
349 self.dispatch_loop();
350 });
351 }
352
353 fn dispatch_loop(&mut self) {
355 let mut events = mio::event::Events::with_capacity(1024);
356
357 loop {
358 self.mio_poll
359 .poll(&mut events, Some(self.tick_interval))
360 .expect("Mio poll panic");
361
362 for event in &events {
363 if event.is_readable() {
364 self.notify(event.token(), Interest::READABLE);
365 }
366
367 if event.is_writable() {
368 self.notify(event.token(), Interest::WRITABLE);
369 }
370 }
371
372 let timeout_timers = self.reactor.timewheel.next_tick();
373
374 if let Some(timeout_timers) = timeout_timers {
375 for token in timeout_timers {
376 self.notify(token, Interest::WRITABLE);
377 }
378 }
379 }
380 }
381
382 fn notify(&self, token: Token, interests: Interest) {
383 self.reactor.notify(token, interests);
384 }
385}
386
/// Process-wide reactor instance, initialised either by `start_reactor_with`
/// or lazily by `global_reactor`.
static GLOBAL_REACTOR: OnceLock<ArcReactor> = OnceLock::new();
388
389pub fn start_reactor_with(tick_interval: Duration) {
400 if GLOBAL_REACTOR
401 .set(Reactor::new(tick_interval).unwrap())
402 .is_err()
403 {
404 panic!("Call start_reactor_with twice.");
405 }
406}
407
408pub fn global_reactor() -> ArcReactor {
413 GLOBAL_REACTOR
414 .get_or_init(|| Reactor::new(Duration::from_millis(10)).unwrap())
415 .clone()
416}
417
#[cfg(test)]
mod tests {
    use std::{sync::Barrier, thread::sleep, time::Duration};

    // NOTE(review): presumably the trait that provides `Token::next()` — confirm.
    use crate::TokenSequence;

    use super::*;

    /// Registers timers from several threads at once, then drains the wheel
    /// from several threads and checks that every timer is reported.
    #[test]
    fn test_add_timers() {
        let threads = 10;
        let loops = 3usize;

        let time_wheel = Arc::new(Timewheel::new(Duration::from_millis(100)));

        // Makes all registering threads start at the same moment.
        let barrier = Arc::new(Barrier::new(threads));

        let mut handles = vec![];

        for _ in 0..threads {
            let barrier = barrier.clone();

            let time_wheel = time_wheel.clone();

            handles.push(std::thread::spawn(move || {
                barrier.wait();

                // Deadlines are 1..=loops seconds ahead; every registration
                // is expected to be accepted (hence the `unwrap`).
                for i in 0..loops {
                    time_wheel
                        .new_timer(
                            Token::next(),
                            Instant::now() + Duration::from_secs((i + 1) as u64),
                        )
                        .unwrap();
                }
            }))
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(time_wheel.timers() as usize, threads * loops);

        let mut handles = vec![];

        // Total number of expired timers seen by all draining threads.
        let counter = Arc::new(AtomicU64::new(0));

        for _ in 0..threads {
            let time_wheel = time_wheel.clone();

            let counter = counter.clone();

            // Spin on `next_tick` until every registered timer was reported.
            handles.push(std::thread::spawn(move || loop {
                if let Some(timers) = time_wheel.next_tick() {
                    counter.fetch_add(timers.len() as u64, Ordering::SeqCst);
                }

                if counter.load(Ordering::SeqCst) == (threads * loops) as u64 {
                    break;
                }
            }))
        }

        for handle in handles {
            handle.join().unwrap();
        }
    }

    /// A timer one tick ahead lands in slot 1 and is returned by `next_tick`
    /// once two tick intervals have elapsed.
    #[test]
    fn test_next_tick() {
        let time_wheel = Timewheel::new(Duration::from_millis(100));

        let token = Token::next();
        assert_eq!(
            time_wheel.new_timer(token, Instant::now() + Duration::from_millis(100)),
            Some(1)
        );

        // Sleep two tick intervals so the wheel moves past slot 1.
        sleep(Duration::from_millis(200));

        assert_eq!(time_wheel.next_tick(), Some(vec![token]));
    }
}