rasi_default/
reactor.rs

1//! A reactor pattern implementation based on [`mio::Poll`].
2//!
3//! This is the implementation detail of the `syscall`s and is best not used directly.
4//!
5//! You can get the global instance of [`Reactor`] by calling function [`global_reactor`].
6//!
7//! # Examples
8//!
9//! ```no_run
10//! # fn main() {
11//! #
12//! use rasi_default::reactor;
13//!
14//! let reactor = reactor::global_reactor();
15//! #
16//! # }
17//! ```
18//!
19use std::{
20    io,
21    ops::Deref,
22    sync::{
23        atomic::{AtomicU64, Ordering},
24        Arc, OnceLock,
25    },
26    task::Waker,
27    time::{Duration, Instant},
28};
29
30use dashmap::DashMap;
31use mio::{
32    event::{self, Source},
33    Interest, Token,
34};
35use rasi_syscall::{CancelablePoll, Handle};
36
/// A wrapper that pairs a mio event source with the [`Token`] associated with it.
///
/// The wrapper dereferences to the inner source and deregisters that source
/// from the global reactor when it is dropped.
pub(crate) struct MioSocket<S: Source> {
    /// Associated token.
    pub(crate) token: Token,
    /// The wrapped mio event source.
    pub(crate) socket: S,
}
44
45impl<S: Source> From<(Token, S)> for MioSocket<S> {
46    fn from(value: (Token, S)) -> Self {
47        Self {
48            token: value.0,
49            socket: value.1,
50        }
51    }
52}
53
54impl<S: Source> Deref for MioSocket<S> {
55    type Target = S;
56    fn deref(&self) -> &Self::Target {
57        &self.socket
58    }
59}
60
61impl<S: Source> Drop for MioSocket<S> {
62    fn drop(&mut self) {
63        if global_reactor().deregister(&mut self.socket).is_err() {}
64    }
65}
66
67/// Create a [`CancelablePoll`] instance from [`std::io::Result`] that returns by function `f`.
68///
69/// If the function `f` result error is [`Interrupted`](io::ErrorKind::Interrupted),
70/// `would_block` will call `f` again immediately.
71pub(crate) fn would_block<T, F>(
72    token: Token,
73    waker: Waker,
74    interests: Interest,
75    mut f: F,
76) -> CancelablePoll<io::Result<T>>
77where
78    F: FnMut() -> io::Result<T>,
79{
80    global_reactor().once(token, interests, waker);
81
82    loop {
83        match f() {
84            Ok(t) => {
85                return {
86                    global_reactor().remove_listeners(token, interests);
87
88                    CancelablePoll::Ready(Ok(t))
89                }
90            }
91            Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
92                return CancelablePoll::Pending(Handle::new(()));
93            }
94            Err(err) if err.kind() == io::ErrorKind::Interrupted => {
95                continue;
96            }
97            Err(err) => {
98                global_reactor().remove_listeners(token, interests);
99                return CancelablePoll::Ready(Err(err));
100            }
101        }
102    }
103}
104
/// The hashed time wheel implementation to handle a massive set of timer tracking tasks.
struct Timewheel {
    /// Length of one tick, in microseconds.
    tick_interval: u64,
    /// The timer map from expiration tick to the queue of timers expiring at that tick.
    timers: DashMap<u64, boxcar::Vec<Token>>,
    /// Clock ticks that have elapsed since this object was created.
    ticks: AtomicU64,
    /// The instant at which this timewheel instance was created.
    start_instant: Instant,
    /// Timer counter, incremented each time `new_timer` succeeds.
    timer_count: AtomicU64,
}
117
118impl Timewheel {
119    fn new(tick_interval: Duration) -> Self {
120        Self {
121            tick_interval: tick_interval.as_micros() as u64,
122            ticks: Default::default(),
123            start_instant: Instant::now(),
124            timer_count: Default::default(),
125            timers: Default::default(),
126        }
127    }
128
129    /// Returns aliving timers's count
130    #[allow(unused)]
131    fn timers(&self) -> u64 {
132        self.timer_count.load(Ordering::Relaxed)
133    }
134
135    /// Creates a new timer and returns the timer expiration ticks.
136    pub fn new_timer(&self, token: Token, deadline: Instant) -> Option<u64> {
137        let ticks = (deadline - self.start_instant).as_micros() as u64 / self.tick_interval;
138
139        let ticks = ticks as u64;
140
141        if self
142            .ticks
143            .fetch_update(Ordering::Release, Ordering::Acquire, |current| {
144                if current > ticks {
145                    None
146                } else {
147                    Some(current)
148                }
149            })
150            .is_err()
151        {
152            return None;
153        }
154
155        self.timers.entry(ticks).or_default().push(token);
156
157        if self.ticks.load(Ordering::SeqCst) > ticks {
158            if self.timers.remove(&ticks).is_some() {
159                return None;
160            }
161        }
162
163        if self
164            .ticks
165            .fetch_update(Ordering::Release, Ordering::Acquire, |current| {
166                if current > ticks {
167                    if self.timers.remove(&ticks).is_some() {
168                        return None;
169                    } else {
170                        return Some(current);
171                    }
172                } else {
173                    return Some(current);
174                }
175            })
176            .is_err()
177        {
178            return None;
179        }
180
181        self.timer_count.fetch_add(1, Ordering::SeqCst);
182
183        Some(ticks)
184    }
185
186    /// Forward to next tick, and returns timeout timers.
187    pub fn next_tick(&self) -> Option<Vec<Token>> {
188        loop {
189            let current = self.ticks.load(Ordering::Acquire);
190
191            let instant_duration = Instant::now() - self.start_instant;
192
193            let ticks = instant_duration.as_micros() as u64 / self.tick_interval;
194
195            assert!(current <= ticks);
196
197            if current == ticks {
198                return None;
199            }
200
201            if self
202                .ticks
203                .compare_exchange(current, ticks, Ordering::AcqRel, Ordering::Relaxed)
204                .is_ok()
205            {
206                let mut timeout_timers = vec![];
207
208                for i in current..ticks {
209                    if let Some((_, queue)) = self.timers.remove(&i) {
210                        for t in queue.into_iter() {
211                            timeout_timers.push(t);
212                        }
213                    }
214                }
215
216                return Some(timeout_timers);
217            }
218        }
219    }
220}
221
/// A reactor pattern implementation based on [`mio::Poll`].
///
/// **Note**: This type is implemented with lock-free structures,
/// so it is only available on platforms that support atomic operations.
pub struct Reactor {
    /// I/O resources registry.
    mio_registry: mio::Registry,
    /// Wakers of pending read operations, keyed by token.
    read_op_wakers: DashMap<Token, Waker>,
    /// Wakers of pending write operations, keyed by token.
    ///
    /// Wakers of timers created through `deadline` are stored here as well.
    write_op_wakers: DashMap<Token, Waker>,
    /// Hashed time wheel implementation.
    timewheel: Timewheel,
}
236
/// A thread-safe reference-counting pointer to a [`Reactor`].
pub type ArcReactor = Arc<Reactor>;
239
240impl Reactor {
241    fn new(tick_interval: Duration) -> io::Result<ArcReactor> {
242        let mio_poll = mio::Poll::new()?;
243
244        let mio_registry = mio_poll.registry().try_clone()?;
245
246        let reactor = Arc::new(Reactor {
247            mio_registry,
248            read_op_wakers: Default::default(),
249            write_op_wakers: Default::default(),
250            timewheel: Timewheel::new(tick_interval),
251        });
252
253        let background = ReactorBackground::new(tick_interval, mio_poll, reactor.clone());
254
255        background.start();
256
257        Ok(reactor)
258    }
259
260    /// Register an [`event::Source`] with the underlying [`mio::Poll`] instance.
261    pub fn register<S>(&self, source: &mut S, token: Token, interests: Interest) -> io::Result<()>
262    where
263        S: event::Source + ?Sized,
264    {
265        self.mio_registry.register(source, token, interests)
266    }
267
268    /// Deregister an [`event::Source`] from the underlying [`mio::Poll`] instance.
269    pub fn deregister<S>(&self, source: &mut S) -> io::Result<()>
270    where
271        S: event::Source + ?Sized,
272    {
273        self.mio_registry.deregister(source)
274    }
275
276    /// Create new `deadline` timer, returns [`None`] if the `deadline` instant is reached.
277    pub fn deadline(&self, token: Token, waker: Waker, deadline: Instant) -> Option<u64> {
278        self.write_op_wakers.insert(token, waker);
279
280        // Adding a timer was successful.
281        if let Some(id) = self.timewheel.new_timer(token, deadline) {
282            Some(id)
283        } else {
284            // Adding a timer fails because the `deadline` has expired.
285            //
286            // So remove the waker from memory immediately.
287            self.write_op_wakers.remove(&token);
288            None
289        }
290    }
291
292    /// Add a [`interests`](Interest) [`listener`](Waker) to this reactor.
293    pub fn once(&self, token: Token, interests: Interest, waker: Waker) {
294        if interests.is_readable() {
295            self.read_op_wakers.insert(token, waker.clone());
296        }
297
298        if interests.is_writable() {
299            self.write_op_wakers.insert(token, waker);
300        }
301    }
302
303    /// notify [`listener`](Waker) by [`Token`] and [`interests`](Interest)
304    pub fn notify(&self, token: Token, interests: Interest) {
305        if interests.is_readable() {
306            if let Some(waker) = self.read_op_wakers.remove(&token).map(|(_, v)| v) {
307                waker.wake();
308            }
309        }
310
311        if interests.is_writable() {
312            if let Some(waker) = self.write_op_wakers.remove(&token).map(|(_, v)| v) {
313                waker.wake();
314            }
315        }
316    }
317
318    /// remove [`listener`](Waker) from this reactor by [`Token`] and [`interests`](Interest)
319    pub fn remove_listeners(&self, token: Token, interests: Interest) {
320        if interests.is_readable() {
321            self.read_op_wakers.remove(&token);
322        }
323
324        if interests.is_writable() {
325            self.write_op_wakers.remove(&token);
326        }
327    }
328}
329
/// The context of [`Reactor`] background thread.
struct ReactorBackground {
    /// The poller that the dispatch loop waits on for readiness events.
    mio_poll: mio::Poll,
    /// The reactor whose listeners and timers are notified by the dispatch loop.
    reactor: ArcReactor,
    /// Timeout passed to every poll call in the dispatch loop.
    tick_interval: Duration,
}
336
337impl ReactorBackground {
338    fn new(tick_interval: Duration, mio_poll: mio::Poll, reactor: ArcReactor) -> Self {
339        Self {
340            mio_poll,
341            reactor,
342            tick_interval,
343        }
344    }
345
346    /// Start readiness events dispatch, and consume self.
347    fn start(mut self) {
348        std::thread::spawn(move || {
349            self.dispatch_loop();
350        });
351    }
352
353    /// Readiness event dispatch loop.
354    fn dispatch_loop(&mut self) {
355        let mut events = mio::event::Events::with_capacity(1024);
356
357        loop {
358            self.mio_poll
359                .poll(&mut events, Some(self.tick_interval))
360                .expect("Mio poll panic");
361
362            for event in &events {
363                if event.is_readable() {
364                    self.notify(event.token(), Interest::READABLE);
365                }
366
367                if event.is_writable() {
368                    self.notify(event.token(), Interest::WRITABLE);
369                }
370            }
371
372            let timeout_timers = self.reactor.timewheel.next_tick();
373
374            if let Some(timeout_timers) = timeout_timers {
375                for token in timeout_timers {
376                    self.notify(token, Interest::WRITABLE);
377                }
378            }
379        }
380    }
381
382    fn notify(&self, token: Token, interests: Interest) {
383        self.reactor.notify(token, interests);
384    }
385}
386
/// Process-wide [`Reactor`] instance, initialized at most once: either by
/// [`start_reactor_with`] or by the first call to [`global_reactor`].
static GLOBAL_REACTOR: OnceLock<ArcReactor> = OnceLock::new();
388
389/// Manually start [`Reactor`] service with providing `tick_interval`.
390///
391/// If `start_reactor_with` is not called at the very beginning of the `main fn`,
392/// [`Reactor`] will run with the default tick_interval = 10ms.
393///
394/// # Panic
395///
396/// Call this function more than once or Call this function after calling any
397/// [`Network`](rasi_syscall::Network) [`Timer`](rasi_syscall::Timer) system interface , will cause a panic with message
398/// `Call start_reactor_with twice.`
399pub fn start_reactor_with(tick_interval: Duration) {
400    if GLOBAL_REACTOR
401        .set(Reactor::new(tick_interval).unwrap())
402        .is_err()
403    {
404        panic!("Call start_reactor_with twice.");
405    }
406}
407
408/// Get the globally registered instance of [`Reactor`].
409///
410/// If call this function before calling [`start_reactor_with`],
411/// the implementation will start [`Reactor`] with tick_interval = 10ms.
412pub fn global_reactor() -> ArcReactor {
413    GLOBAL_REACTOR
414        .get_or_init(|| Reactor::new(Duration::from_millis(10)).unwrap())
415        .clone()
416}
417
#[cfg(test)]
mod tests {
    use std::{sync::Barrier, thread::sleep, time::Duration};

    use crate::TokenSequence;

    use super::*;

    /// Adds timers from several threads at once, then drains the wheel from
    /// several threads until every added timer has been reported.
    #[test]
    fn test_add_timers() {
        let threads = 10;
        let loops = 3usize;

        let time_wheel = Arc::new(Timewheel::new(Duration::from_millis(100)));

        // All producers start adding timers at the same moment.
        let barrier = Arc::new(Barrier::new(threads));

        let expected = threads * loops;

        let producers: Vec<_> = (0..threads)
            .map(|_| {
                let barrier = Arc::clone(&barrier);
                let time_wheel = Arc::clone(&time_wheel);

                std::thread::spawn(move || {
                    barrier.wait();

                    // Deadlines are 1s, 2s, ... `loops` seconds from now.
                    for secs in 1..=loops {
                        time_wheel
                            .new_timer(
                                Token::next(),
                                Instant::now() + Duration::from_secs(secs as u64),
                            )
                            .unwrap();
                    }
                })
            })
            .collect();

        for producer in producers {
            producer.join().unwrap();
        }

        assert_eq!(time_wheel.timers() as usize, expected);

        let counter = Arc::new(AtomicU64::new(0));

        let consumers: Vec<_> = (0..threads)
            .map(|_| {
                let time_wheel = Arc::clone(&time_wheel);
                let counter = Arc::clone(&counter);

                std::thread::spawn(move || loop {
                    if let Some(timers) = time_wheel.next_tick() {
                        counter.fetch_add(timers.len() as u64, Ordering::SeqCst);
                    }

                    // Stop once all timers have been seen by some consumer.
                    if counter.load(Ordering::SeqCst) == expected as u64 {
                        break;
                    }
                })
            })
            .collect();

        for consumer in consumers {
            consumer.join().unwrap();
        }
    }

    /// A timer one tick away is returned by `next_tick` after two ticks have passed.
    #[test]
    fn test_next_tick() {
        let time_wheel = Timewheel::new(Duration::from_millis(100));

        let token = Token::next();
        let created = time_wheel.new_timer(token, Instant::now() + Duration::from_millis(100));

        assert_eq!(created, Some(1));

        sleep(Duration::from_millis(200));

        let fired = time_wheel.next_tick();

        assert_eq!(fired, Some(vec![token]));
    }
}
499        assert_eq!(time_wheel.next_tick(), Some(vec![token]));
500    }
501}