use std::{
io,
ops::Deref,
sync::{
atomic::{AtomicU64, Ordering},
Arc, OnceLock,
},
task::Waker,
time::{Duration, Instant},
};
use dashmap::DashMap;
use mio::{
event::{self, Source},
Interest, Token,
};
use rasi_syscall::{CancelablePoll, Handle};
/// Pairs a mio event source with the [`Token`] that identifies it.
///
/// The wrapper dereferences to the inner source and deregisters that source
/// from the global reactor when it is dropped.
pub(crate) struct MioSocket<S: Source> {
/// Token associated with `socket`.
// NOTE(review): presumably the token the source was registered under — confirm at the registration site.
pub(crate) token: Token,
/// The wrapped mio event source.
pub(crate) socket: S,
}
/// Builds a [`MioSocket`] from a `(token, source)` pair.
impl<S: Source> From<(Token, S)> for MioSocket<S> {
    fn from((token, socket): (Token, S)) -> Self {
        Self { token, socket }
    }
}
/// Gives transparent shared access to the wrapped event source.
impl<S: Source> Deref for MioSocket<S> {
    type Target = S;

    /// Borrows the inner source.
    fn deref(&self) -> &S {
        let Self { socket, .. } = self;
        socket
    }
}
/// Deregisters the wrapped source from the global reactor on drop.
impl<S: Source> Drop for MioSocket<S> {
    fn drop(&mut self) {
        // Best effort: a failed deregistration is deliberately ignored, since
        // there is nothing a destructor could do about it.
        let _ = global_reactor().deregister(&mut self.socket);
    }
}
/// Runs the non-blocking operation `f`, translating its outcome into a
/// [`CancelablePoll`].
///
/// `waker` is stored with the global reactor for `token`/`interests` *before*
/// `f` is attempted, so a readiness event arriving in between is not lost.
///
/// * `Ok(_)` or any error other than the two below: the stored waker is
///   removed again and the result is returned as `Ready`.
/// * `ErrorKind::WouldBlock`: the waker stays registered and `Pending` is
///   returned.
/// * `ErrorKind::Interrupted`: `f` is simply retried.
pub(crate) fn would_block<T, F>(
    token: Token,
    waker: Waker,
    interests: Interest,
    mut f: F,
) -> CancelablePoll<io::Result<T>>
where
    F: FnMut() -> io::Result<T>,
{
    let reactor = global_reactor();
    reactor.once(token, interests, waker);

    loop {
        // Split off the success path first, then classify the error.
        let err = match f() {
            Ok(value) => {
                reactor.remove_listeners(token, interests);
                return CancelablePoll::Ready(Ok(value));
            }
            Err(err) => err,
        };

        match err.kind() {
            io::ErrorKind::WouldBlock => {
                return CancelablePoll::Pending(Handle::new(()));
            }
            io::ErrorKind::Interrupted => continue,
            _ => {
                reactor.remove_listeners(token, interests);
                return CancelablePoll::Ready(Err(err));
            }
        }
    }
}
/// Timer store that groups tokens by the tick in which their deadline falls.
struct Timewheel {
/// Length of one tick, in microseconds.
tick_interval: u64,
/// Pending timer tokens, keyed by the tick index of their deadline.
timers: DashMap<u64, boxcar::Vec<Token>>,
/// Tick index the wheel has advanced to so far.
ticks: AtomicU64,
/// Origin of the tick timeline; tick indices are measured from this instant.
start_instant: Instant,
/// Number of timers accepted by `new_timer`; nothing in this file decrements it.
timer_count: AtomicU64,
}
impl Timewheel {
fn new(tick_interval: Duration) -> Self {
Self {
tick_interval: tick_interval.as_micros() as u64,
ticks: Default::default(),
start_instant: Instant::now(),
timer_count: Default::default(),
timers: Default::default(),
}
}
#[allow(unused)]
fn timers(&self) -> u64 {
self.timer_count.load(Ordering::Relaxed)
}
pub fn new_timer(&self, token: Token, deadline: Instant) -> Option<u64> {
let ticks = (deadline - self.start_instant).as_micros() as u64 / self.tick_interval;
let ticks = ticks as u64;
if self
.ticks
.fetch_update(Ordering::Release, Ordering::Acquire, |current| {
if current > ticks {
None
} else {
Some(current)
}
})
.is_err()
{
return None;
}
self.timers.entry(ticks).or_default().push(token);
if self.ticks.load(Ordering::SeqCst) > ticks {
if self.timers.remove(&ticks).is_some() {
return None;
}
}
if self
.ticks
.fetch_update(Ordering::Release, Ordering::Acquire, |current| {
if current > ticks {
if self.timers.remove(&ticks).is_some() {
return None;
} else {
return Some(current);
}
} else {
return Some(current);
}
})
.is_err()
{
return None;
}
self.timer_count.fetch_add(1, Ordering::SeqCst);
Some(ticks)
}
pub fn next_tick(&self) -> Option<Vec<Token>> {
loop {
let current = self.ticks.load(Ordering::Acquire);
let instant_duration = Instant::now() - self.start_instant;
let ticks = instant_duration.as_micros() as u64 / self.tick_interval;
assert!(current <= ticks);
if current == ticks {
return None;
}
if self
.ticks
.compare_exchange(current, ticks, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
let mut timeout_timers = vec![];
for i in current..ticks {
if let Some((_, queue)) = self.timers.remove(&i) {
for t in queue.into_iter() {
timeout_timers.push(t);
}
}
}
return Some(timeout_timers);
}
}
}
}
/// Shared reactor state: the mio registry, the wakers waiting for I/O
/// readiness, and the timer wheel.
pub struct Reactor {
/// Registry handle cloned from the `mio::Poll` that the background thread polls.
mio_registry: mio::Registry,
/// Wakers to wake when their token is notified as readable.
read_op_wakers: DashMap<Token, Waker>,
/// Wakers to wake when their token is notified as writable; `deadline` also
/// stores timer wakers here.
write_op_wakers: DashMap<Token, Waker>,
/// Timer wheel that `deadline` schedules timers on.
timewheel: Timewheel,
}
/// Reference-counted, shareable handle to a [`Reactor`].
pub type ArcReactor = Arc<Reactor>;
impl Reactor {
fn new(tick_interval: Duration) -> io::Result<ArcReactor> {
let mio_poll = mio::Poll::new()?;
let mio_registry = mio_poll.registry().try_clone()?;
let reactor = Arc::new(Reactor {
mio_registry,
read_op_wakers: Default::default(),
write_op_wakers: Default::default(),
timewheel: Timewheel::new(tick_interval),
});
let background = ReactorBackground::new(tick_interval, mio_poll, reactor.clone());
background.start();
Ok(reactor)
}
pub fn register<S>(&self, source: &mut S, token: Token, interests: Interest) -> io::Result<()>
where
S: event::Source + ?Sized,
{
self.mio_registry.register(source, token, interests)
}
pub fn deregister<S>(&self, source: &mut S) -> io::Result<()>
where
S: event::Source + ?Sized,
{
self.mio_registry.deregister(source)
}
pub fn deadline(&self, token: Token, waker: Waker, deadline: Instant) -> Option<u64> {
self.write_op_wakers.insert(token, waker);
if let Some(id) = self.timewheel.new_timer(token, deadline) {
Some(id)
} else {
self.write_op_wakers.remove(&token);
None
}
}
pub fn once(&self, token: Token, interests: Interest, waker: Waker) {
if interests.is_readable() {
self.read_op_wakers.insert(token, waker.clone());
}
if interests.is_writable() {
self.write_op_wakers.insert(token, waker);
}
}
pub fn notify(&self, token: Token, interests: Interest) {
if interests.is_readable() {
if let Some(waker) = self.read_op_wakers.remove(&token).map(|(_, v)| v) {
waker.wake();
}
}
if interests.is_writable() {
if let Some(waker) = self.write_op_wakers.remove(&token).map(|(_, v)| v) {
waker.wake();
}
}
}
pub fn remove_listeners(&self, token: Token, interests: Interest) {
if interests.is_readable() {
self.read_op_wakers.remove(&token);
}
if interests.is_writable() {
self.write_op_wakers.remove(&token);
}
}
}
/// State owned by the background thread that drives a [`Reactor`].
struct ReactorBackground {
/// Poll instance that `dispatch_loop` polls for I/O events.
mio_poll: mio::Poll,
/// Reactor that is notified about I/O events and expired timers.
reactor: ArcReactor,
/// Timeout passed to every `poll` call.
tick_interval: Duration,
}
impl ReactorBackground {
fn new(tick_interval: Duration, mio_poll: mio::Poll, reactor: ArcReactor) -> Self {
Self {
mio_poll,
reactor,
tick_interval,
}
}
fn start(mut self) {
std::thread::spawn(move || {
self.dispatch_loop();
});
}
fn dispatch_loop(&mut self) {
let mut events = mio::event::Events::with_capacity(1024);
loop {
self.mio_poll
.poll(&mut events, Some(self.tick_interval))
.expect("Mio poll panic");
for event in &events {
if event.is_readable() {
self.notify(event.token(), Interest::READABLE);
}
if event.is_writable() {
self.notify(event.token(), Interest::WRITABLE);
}
}
let timeout_timers = self.reactor.timewheel.next_tick();
if let Some(timeout_timers) = timeout_timers {
for token in timeout_timers {
self.notify(token, Interest::WRITABLE);
}
}
}
}
fn notify(&self, token: Token, interests: Interest) {
self.reactor.notify(token, interests);
}
}
/// Process-wide reactor; set once, either explicitly by `start_reactor_with`
/// or lazily by the first `global_reactor` call.
static GLOBAL_REACTOR: OnceLock<ArcReactor> = OnceLock::new();
/// Installs the global reactor with a custom `tick_interval`.
///
/// Must be called at most once, and before anything triggers the lazy
/// initialisation performed by [`global_reactor`].
///
/// # Panics
///
/// Panics if the global reactor has already been set (by an earlier call or by
/// [`global_reactor`]), or if the reactor cannot be created.
pub fn start_reactor_with(tick_interval: Duration) {
    // Check before constructing: `Reactor::new` spawns a background poll
    // thread that is never stopped, so building a reactor only to discard it
    // would leave that thread running after the panic below.
    if GLOBAL_REACTOR.get().is_some() {
        panic!("Call start_reactor_with twice.");
    }

    let reactor = Reactor::new(tick_interval).unwrap();

    // Still needed for the case where another thread installed a reactor
    // between the check above and this point.
    if GLOBAL_REACTOR.set(reactor).is_err() {
        panic!("Call start_reactor_with twice.");
    }
}
/// Returns a handle to the global reactor, creating it with a 10 ms tick
/// interval on first use if `start_reactor_with` was never called.
///
/// # Panics
///
/// Panics if the reactor has to be created here and creation fails.
pub fn global_reactor() -> ArcReactor {
    let reactor = GLOBAL_REACTOR.get_or_init(|| {
        let default_tick = Duration::from_millis(10);
        Reactor::new(default_tick).unwrap()
    });

    Arc::clone(reactor)
}
#[cfg(test)]
mod tests {
    use std::{sync::Barrier, thread::sleep, time::Duration};

    use crate::TokenSequence;

    use super::*;

    /// Adds timers from several threads at once, then drains the wheel from
    /// several threads until every timer has been reported.
    #[test]
    fn test_add_timers() {
        let threads = 10;
        let loops = 3usize;
        let expected = threads * loops;

        let time_wheel = Arc::new(Timewheel::new(Duration::from_millis(100)));
        let barrier = Arc::new(Barrier::new(threads));

        // Phase 1: every producer waits on the barrier so that the inserts
        // really do race with each other.
        let producers: Vec<_> = (0..threads)
            .map(|_| {
                let barrier = barrier.clone();
                let time_wheel = time_wheel.clone();

                std::thread::spawn(move || {
                    barrier.wait();

                    for i in 0..loops {
                        let token = Token::next();
                        let deadline = Instant::now() + Duration::from_secs((i + 1) as u64);

                        time_wheel.new_timer(token, deadline).unwrap();
                    }
                })
            })
            .collect();

        for producer in producers {
            producer.join().unwrap();
        }

        assert_eq!(time_wheel.timers() as usize, expected);

        // Phase 2: consumers spin on `next_tick` until the shared counter
        // shows that all timers have fired.
        let counter = Arc::new(AtomicU64::new(0));

        let consumers: Vec<_> = (0..threads)
            .map(|_| {
                let time_wheel = time_wheel.clone();
                let counter = counter.clone();

                std::thread::spawn(move || loop {
                    if let Some(timers) = time_wheel.next_tick() {
                        counter.fetch_add(timers.len() as u64, Ordering::SeqCst);
                    }

                    if counter.load(Ordering::SeqCst) == expected as u64 {
                        break;
                    }
                })
            })
            .collect();

        for consumer in consumers {
            consumer.join().unwrap();
        }
    }

    /// A timer one tick in the future is queued under tick 1 and is reported
    /// once the wheel has moved past that tick.
    #[test]
    fn test_next_tick() {
        let time_wheel = Timewheel::new(Duration::from_millis(100));
        let token = Token::next();

        let queued_at = time_wheel.new_timer(token, Instant::now() + Duration::from_millis(100));
        assert_eq!(queued_at, Some(1));

        sleep(Duration::from_millis(200));

        let fired = time_wheel.next_tick();
        assert_eq!(fired, Some(vec![token]));
    }
}