nio 0.1.4

Async runtime for Rust
Documentation
use crate::{
    LocalContext, RuntimeContext,
    driver::{self, Driver},
    rt::{context::NioContext, task::LocalScheduler, task_queue::TaskQueue},
};
use nio_task::Status;
use std::{
    io,
    ops::ControlFlow,
    rc::Rc,
    sync::Arc,
    task::{Context, Poll, Waker},
    time::Duration,
};

/// Per-worker event loop: repeatedly polls tasks from the worker's local queue,
/// fires expired timers, and polls the I/O driver for readiness events.
///
/// NOTE: field order determines drop order (`driver` is dropped before
/// `local_ctx`); keep that in mind before reordering fields.
pub struct EventLoop {
    /// Maximum number of local tasks polled per loop iteration before timers
    /// and the I/O driver are serviced again.
    tick: u32,
    /// I/O driver polled once per loop iteration for readiness events.
    driver: Driver,
    /// Worker-local state (local/shared task queues, timers, runtime context),
    /// shared through an `Rc`.
    pub local_ctx: Rc<LocalContext>,
}

impl EventLoop {
    /// Creates the event loop for the worker identified by `id`.
    ///
    /// Resolves the worker id through `runtime_ctx.workers`, obtains an owned
    /// I/O registry from `driver`, builds a [`LocalContext`] with a local queue
    /// of capacity `local_queue_cap`, and calls `init` on it.
    ///
    /// `tick` is the maximum number of local tasks polled between two services
    /// of the timers and I/O driver. A value of `0` is clamped to `1`: with `0`
    /// no task would ever be polled, while the non-empty local queue would make
    /// the loop poll the driver with a zero timeout forever.
    ///
    /// # Panics
    ///
    /// Panics if `driver.registry_owned()` fails.
    pub fn new(
        id: u8,
        driver: driver::Driver,
        runtime_ctx: Arc<RuntimeContext>,
        tick: u32,
        local_queue_cap: usize,
    ) -> Self {
        let worker_id = runtime_ctx.workers.id(id);
        let io_registry = driver
            .registry_owned()
            .expect("failed to obtain an owned registry from the I/O driver");
        let local_ctx = LocalContext::new(worker_id, local_queue_cap, runtime_ctx, io_registry);
        local_ctx.clone().init();
        EventLoop {
            tick: tick.max(1),
            driver,
            local_ctx,
        }
    }

    /// Spawns `fut` as a task on this worker and drives the event loop until
    /// that task reports `Complete`, then returns its output.
    ///
    /// Other tasks in the local queue keep being polled while waiting.
    ///
    /// # Panics
    ///
    /// Panics if the task's join handle resolves to an error.
    pub fn run_until<Fut: Future>(&mut self, fut: Fut) -> Fut::Output {
        // NOTE(review): the safety contract of `LocalScheduler::spawn` is defined
        // elsewhere; presumably the task must only be polled on this worker's
        // thread, which holds because it is pushed to this worker's local queue.
        let (task, jh) = unsafe {
            LocalScheduler::spawn(
                self.local_ctx.worker_id,
                self.local_ctx.runtime_ctx.clone(),
                fut,
            )
        };

        let root_id = task.id();
        self.local_ctx.add_task_to_local_queue(task);

        self.run_with(|this, task_queue| {
            for _ in 0..this.tick {
                let Some(task) = (unsafe { this.local_ctx.local_queue(|q| q.pop_front()) }) else {
                    break;
                };
                match task.poll() {
                    Status::Yielded(task) => {
                        unsafe { this.local_ctx.local_queue(|q| q.push_back(task)) };
                    }
                    Status::Pending => this.on_task_descheduled(task_queue),
                    Status::Complete(meta) => {
                        this.on_task_descheduled(task_queue);
                        // Stop the loop as soon as the root future has finished.
                        if meta.id() == root_id {
                            return ControlFlow::Break(());
                        }
                    }
                }
            }
            ControlFlow::Continue(())
        });

        // `run_with` only returns after the root task reported `Complete`, so its
        // join handle is expected to be ready without registering a real waker.
        let jh = std::pin::pin!(jh);
        match jh.poll(&mut Context::from_waker(Waker::noop())) {
            Poll::Ready(result) => result.unwrap(),
            Poll::Pending => {
                unreachable!("root task completed but its join handle is still pending")
            }
        }
    }

    /// Runs the event loop forever, polling local tasks with
    /// [`EventLoop::execute_tasks`] on every iteration.
    pub fn run(&mut self) {
        self.run_with(Self::execute_tasks);
    }

    /// Polls up to `self.tick` tasks from the local queue.
    ///
    /// Yielded tasks are pushed to the back of the local queue; pending and
    /// completed tasks go through the descheduling bookkeeping. Always returns
    /// `ControlFlow::Continue(())`.
    pub fn execute_tasks(&self, task_queue: &TaskQueue) -> ControlFlow<(), ()> {
        for _ in 0..self.tick {
            let Some(task) = (unsafe { self.local_ctx.local_queue(|q| q.pop_front()) }) else {
                break;
            };
            match task.poll() {
                Status::Yielded(task) => {
                    unsafe { self.local_ctx.local_queue(|q| q.push_back(task)) };
                }
                Status::Pending | Status::Complete(_) => self.on_task_descheduled(task_queue),
            }
        }
        ControlFlow::Continue(())
    }

    /// Bookkeeping for a task that left the local queue without being re-queued
    /// (it returned `Pending` or `Complete`): decrements the task queue's local
    /// counter and pulls work from the shared queue based on the returned state.
    fn on_task_descheduled(&self, task_queue: &TaskQueue) {
        let counter = task_queue.decrease_local();
        self.local_ctx
            .move_tasks_from_shared_to_local_queue(counter);
    }

    /// Drives the event loop, calling `process_tasks` once per iteration until
    /// it returns `ControlFlow::Break`.
    ///
    /// After `process_tasks`, each iteration:
    /// 1. fires expired timers;
    /// 2. loads the task-queue state (accepting a cross-thread notification
    ///    first when the local queue is empty) and moves tasks from the shared
    ///    queue if it has data;
    /// 3. polls the I/O driver with a zero timeout when local work is pending,
    ///    otherwise with the timeout returned by the timers' `next_timeout`;
    /// 4. notifies the waker of every I/O event other than the driver's own
    ///    wake-up event.
    ///
    /// NOTE(review): `local_queue` / `timers` are `unsafe`; presumably they grant
    /// exclusive access to worker-local state and must not be re-entered from
    /// inside their closures. None of the closures here re-enter — confirm
    /// against `LocalContext`.
    ///
    /// # Panics
    ///
    /// Panics if the I/O driver returns an error other than `Interrupted`
    /// (or, on WASI, `InvalidInput`).
    pub fn run_with(&mut self, process_tasks: impl Fn(&Self, &TaskQueue) -> ControlFlow<(), ()>) {
        let task_queue = self.local_ctx.task_queue();

        loop {
            if process_tasks(self, task_queue).is_break() {
                return;
            }

            let expired_timers = unsafe {
                self.local_ctx
                    .timers(|timer| timer.fetch(timer.clock.now()))
            };
            expired_timers.notify_all();

            // Checked after firing timers, since notified tasks may have been queued.
            let mut local_queue_is_empty = unsafe { self.local_ctx.local_queue(|q| q.is_empty()) };

            let counter = if local_queue_is_empty {
                // Accept notification from other threads.
                let (_notify_flag_removed, state) =
                    task_queue.accept_notify_once_if_shared_queue_is_empty();

                #[cfg(feature = "metrics")]
                if _notify_flag_removed {
                    self.local_ctx
                        .runtime_ctx
                        .measurement
                        .queue_drained(self.local_ctx.worker_id.get());
                }

                state
            } else {
                task_queue.load()
            };

            if counter.shared_queue_has_data() {
                self.local_ctx
                    .move_tasks_from_shared_to_local_queue(counter);
                local_queue_is_empty = false;
            }

            let timeout = unsafe {
                self.local_ctx.timers(|timer| {
                    if local_queue_is_empty {
                        // No immediate work: sleep until the next timer fires, or
                        // until woken by an I/O event or another thread sending more tasks.
                        return timer.next_timeout(timer.clock.current());
                    }
                    // Do not sleep; there is more work to do.
                    Some(Duration::ZERO)
                })
            };

            // `driver.poll` clears wake-up notifications.
            let events = match self.driver.poll(timeout) {
                Ok(events) => events,
                Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue,
                #[cfg(target_os = "wasi")]
                Err(e) if e.kind() == io::ErrorKind::InvalidInput => {
                    // On wasm32-wasi this error occurs when polling without any
                    // subscriptions; nothing could wake us, so retry the loop.
                    continue;
                }
                Err(e) => panic!("unexpected error when polling the I/O driver: {e:?}"),
            };

            for event in events {
                // The driver's own wake-up event carries no I/O waker.
                if Driver::has_woken(event) {
                    continue;
                }
                // NOTE(review): assumes the event token encodes a pointer to an
                // `IoWaker` that is still alive — confirm against the driver.
                let ptr = driver::IoWaker::from(event.token().0);
                unsafe { (*ptr).notify(event) };
            }
        }
    }
}

impl Drop for EventLoop {
    /// Releases the worker-local context when the event loop goes away.
    ///
    /// NOTE(review): presumably undoes the `init` call made in
    /// `EventLoop::new` — confirm against `NioContext`.
    fn drop(&mut self) {
        NioContext::drop_local_context()
    }
}