use std::{
any::TypeId,
pin::{Pin, pin},
task::{Context, Poll},
};
use slab::Slab;
use tempest_io::Io;
use crate::{
context::{CURRENT_CONTEXT, RuntimeContext, TaskId, WakeSets, make_waker, parse_op_handle},
task::Tasks,
};
/// Executor that drives a main future plus spawned tasks on top of an [`Io`] backend.
///
/// Each call to [`Runtime::tick`] performs one scheduling round; [`Runtime::block_on`]
/// repeats rounds until the main future completes.
pub struct Runtime<I>
where
    I: Io,
{
    /// I/O backend that every tick polls (and parks on when nothing is runnable).
    io: I,
    /// Spawned tasks, addressed by slab key.
    tasks: Tasks,
    /// Slab keys of tasks that completed during the current tick; removed at its end.
    finished_tasks: Vec<usize>,
    /// Wake-up bookkeeping: `active` is polled this tick, `staging` receives new
    /// wake-ups (presumably promoted by `WakeSets::swap` — confirm).
    wake_sets: WakeSets,
    /// Counter exposed to tasks via `RuntimeContext::next_op`; presumably the next
    /// I/O operation id — confirm against `context`.
    next_op: u64,
}
impl<I: Io> Runtime<I> {
    /// Creates a runtime over `io` with no spawned tasks and no pending wake-ups.
    pub fn new(io: I) -> Self {
        Self {
            io,
            tasks: Slab::new(),
            finished_tasks: Vec::new(),
            wake_sets: WakeSets::default(),
            next_op: 0,
        }
    }

    /// Returns mutable access to the underlying I/O backend.
    pub fn inspect_io(&mut self) -> &mut I {
        &mut self.io
    }

    /// Marks every task that owns a pending I/O completion as active for this tick.
    ///
    /// The owning task is recovered from each completion's operation handle.
    fn wake_active_by_io_completions(&mut self) {
        for (handle, _) in self.io.completions() {
            let (task_id, _) = parse_op_handle(*handle);
            self.wake_sets.active.insert(task_id);
        }
    }

    /// Runs one scheduling round and returns the main future's result if it completed.
    ///
    /// A round does the following:
    /// 1. Swaps the wake sets.
    /// 2. Polls the I/O backend and wakes tasks that have completions.
    /// 3. If nothing is runnable, parks on I/O and collects completions again.
    /// 4. Polls every active task once.
    ///
    /// `fut` is the main task; it is polled only if `TaskId::Main` was woken. Spawned
    /// tasks that return `Ready` are removed from the slab at the end of the round.
    /// Wake-ups for tasks that no longer exist are ignored.
    ///
    /// Returns `Poll::Ready` with the main future's output once it completes, otherwise
    /// `Poll::Pending`.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - the I/O backend fails to poll or park;
    /// - no task is runnable and no I/O is in flight (deadlock);
    /// - I/O completions remain unconsumed after all active tasks were polled.
    pub fn tick<F: Future>(&mut self, fut: &mut Pin<&mut F>) -> Poll<F::Output> {
        /// Clears `CURRENT_CONTEXT` when dropped, so the raw pointers into this runtime
        /// are unpublished even if a polled future or a cleanup step panics.
        struct ContextGuard;
        impl Drop for ContextGuard {
            fn drop(&mut self) {
                CURRENT_CONTEXT.set(None);
            }
        }

        self.wake_sets.swap();

        // The I/O bookkeeping below accesses `self.io` / `self.wake_sets` directly.
        // It runs before the task context is published, so these `&mut self.*`
        // reborrows do not overlap with the raw pointers that tasks dereference
        // through `CURRENT_CONTEXT`.
        self.io.poll().expect("fatal: io poll failed");
        self.wake_active_by_io_completions();
        assert!(
            !self.wake_sets.active.is_empty() || self.io.in_flight() > 0,
            "deadlock: wake set is empty and no I/O in flight"
        );
        if self.wake_sets.active.is_empty() {
            self.io.park().expect("fatal: io park failed");
            self.wake_active_by_io_completions();
        }

        // Move the active set out before deriving the context pointers. The set is
        // iterated locally, and its allocation is handed back after polling.
        let mut active = std::mem::take(&mut self.wake_sets.active);

        let ctx = RuntimeContext {
            type_id: TypeId::of::<I>(),
            io: &mut self.io as *mut I as *mut (),
            tasks: &mut self.tasks as *mut _,
            wake_sets: &mut self.wake_sets as *mut _,
            next_op: &mut self.next_op as *mut _,
        };
        CURRENT_CONTEXT.set(Some(ctx));
        let guard = ContextGuard;

        let mut result = Poll::Pending;
        for &task in &active {
            let waker = make_waker(task);
            let mut cx = Context::from_waker(&waker);
            match task {
                TaskId::Main => {
                    result = fut.as_mut().poll(&mut cx);
                }
                TaskId::Task(id) => {
                    let index = id.get() as usize;
                    // A waker may fire after its task already completed, for example
                    // when another task holds a stale waker. Ignore such wake-ups
                    // instead of panicking on a missing slab entry.
                    let Some(task_fut) = self.tasks.get_mut(index) else {
                        continue;
                    };
                    if task_fut.as_mut().poll(&mut cx).is_ready() {
                        self.finished_tasks.push(index);
                    }
                }
            }
        }

        // Hand the (now cleared) allocation back for reuse next tick.
        active.clear();
        self.wake_sets.active = active;
        // Finished tasks are dropped while the context is still published, as before,
        // in case their destructors reach the runtime through it.
        for key in self.finished_tasks.drain(..) {
            let _ = self.tasks.remove(key);
        }
        assert!(self.io.completions().is_empty(), "leaked io completions");
        drop(guard);
        result
    }

    /// Drives `fut` to completion on this runtime, polling woken spawned tasks
    /// alongside it.
    ///
    /// The main task is staged for wake-up first, so the initial tick polls it.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Runtime::tick`].
    pub fn block_on<F: Future>(&mut self, fut: F) -> F::Output {
        let mut fut = pin!(fut);
        self.wake_sets.staging.insert(TaskId::Main);
        loop {
            if let Poll::Ready(value) = self.tick(&mut fut) {
                return value;
            }
        }
    }
}
impl<I> Default for Runtime<I>
where
I: Io + Default,
{
fn default() -> Self {
Self::new(I::default())
}
}
/// Convenience entry point: builds a fresh [`Runtime`] over `io` and drives `fut` to
/// completion on it.
///
/// # Panics
///
/// Panics under the same conditions as [`Runtime::tick`].
pub fn block_on<I, F>(io: I, fut: F) -> F::Output
where
    I: Io,
    F: Future,
{
    let mut runtime = Runtime::new(io);
    runtime.block_on(fut)
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempest_io::VirtualIo;

    /// A future that is ready on its first poll completes in a single tick.
    #[test]
    fn immediate_ready() {
        let answer = block_on(VirtualIo::default(), async { 42 });
        assert_eq!(answer, 42);
    }

    /// A future that wakes itself and returns `Pending` is re-polled until it is ready.
    #[test]
    fn multiple_ticks() {
        let mut poll_count = 0u32;
        let self_waking = std::future::poll_fn(|cx| {
            poll_count += 1;
            if poll_count < 3 {
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready("done")
        });
        let outcome = block_on(VirtualIo::default(), self_waking);
        assert_eq!(outcome, "done");
        assert_eq!(poll_count, 3);
    }
}