use futures_task::{ArcWake, waker_ref};
use pin_project_lite::pin_project;
use crate::{
bindings::{
_vlib_node_registration, async_context, vl_api_force_rpc_call_main_thread,
vlib_helper_get_global_main, vlib_helper_process_node_loop,
vlib_helper_remove_node_from_registrations, vlib_main_t, vlib_node_registration_t,
vlib_node_runtime_t, vlib_process_signal_event_mt_args_t,
vlib_process_signal_event_mt_helper,
},
vlib::{
MainRef, NodeRuntimeRef,
node::{ErrorCounters, NextNodes},
process_node::tw_timer::{Timer, TimerWheel},
},
};
use std::{
cell::{RefCell, UnsafeCell},
ffi::c_void,
fmt,
future::Future,
pin::Pin,
rc::Rc,
sync::Arc,
task::{Context, Poll},
time::{Duration, Instant},
};
pub use futures_task::LocalFutureObj;
/// Length of one timer-wheel tick in milliseconds; elapsed milliseconds are divided by this
/// value to obtain a tick count.
const TICK_INTERVAL_PER_MS: u64 = 1;
/// The same tick length expressed in seconds, used to convert tick counts into `f64` seconds.
const TICK_INTERVAL_S: f64 = TICK_INTERVAL_PER_MS as f64 / 1000.0;
/// A process node whose body is written as a future.
///
/// NOTE(review): how an implementation gets wrapped into a `ProcessAsyncContext` is not
/// visible in this file — presumably done by registration glue elsewhere; confirm.
pub trait ProcessNode {
/// Next-node description for this node; must implement `NextNodes`.
type NextNodes: NextNodes;
/// Per-node runtime data; must be `Send + Copy`.
type RuntimeData: Send + Copy;
/// Error counters for this node; must implement `ErrorCounters`.
type Errors: ErrorCounters;
/// Body of the process node.
///
/// Receives the main reference and this node's runtime reference and returns a future with
/// `()` output. Nothing runs until the returned future is polled.
#[must_use = "Futures do nothing unless awaited"]
fn function(
&self,
vm: &mut MainRef,
node: &mut NodeRuntimeRef<Self>,
) -> impl Future<Output = ()>;
}
/// Registration record for a process node of type `N` with `N_NEXT_NODES` next-node names.
///
/// Wraps the raw `_vlib_node_registration` in an `UnsafeCell` because `register` links it
/// into the global registration list through a shared reference.
pub struct ProcessNodeRegistration<N: ProcessNode, const N_NEXT_NODES: usize> {
// Raw registration structure handed to the C side by pointer.
registration: UnsafeCell<_vlib_node_registration<[*mut std::os::raw::c_char; N_NEXT_NODES]>>,
// Ties the registration to the node implementation type without storing a value of it.
_marker: std::marker::PhantomData<N>,
}
impl<N: ProcessNode, const N_NEXT_NODES: usize> ProcessNodeRegistration<N, N_NEXT_NODES> {
    /// Wraps a raw node registration so it can later be linked into the global list.
    ///
    /// Usable in `const` context, e.g. to initialise a `static`.
    pub const fn new(
        registration: _vlib_node_registration<[*mut std::os::raw::c_char; N_NEXT_NODES]>,
    ) -> Self {
        let registration = UnsafeCell::new(registration);
        Self {
            registration,
            _marker: std::marker::PhantomData,
        }
    }

    /// Pushes this registration onto the head of the global node registration list.
    ///
    /// # Safety
    ///
    /// The global main structure and this registration are written through raw pointers
    /// without any synchronisation, so the caller must guarantee that nothing else accesses
    /// them concurrently.
    pub unsafe fn register(&'static self) {
        let this_reg = self.registration.get();
        unsafe {
            let global_main = vlib_helper_get_global_main();
            // Classic intrusive list push-front: point at the old head, then become the head.
            let previous_head = (*global_main).node_registrations;
            (*this_reg).next_registration = previous_head;
            (*global_main).node_registrations = this_reg as *mut vlib_node_registration_t;
        }
    }

    /// Removes this registration from the global node registration list.
    ///
    /// # Safety
    ///
    /// The global registration list is modified through raw pointers without any
    /// synchronisation, so the caller must guarantee that nothing else accesses it
    /// concurrently.
    pub unsafe fn unregister(&self) {
        let this_reg = self.registration.get() as *mut vlib_node_registration_t;
        unsafe {
            let global_main = vlib_helper_get_global_main();
            vlib_helper_remove_node_from_registrations(global_main, this_reg);
        }
    }

    /// Reinterprets a raw node runtime pointer as a typed `NodeRuntimeRef<N>`.
    ///
    /// # Safety
    ///
    /// `ptr` is converted to a mutable reference with a caller-chosen lifetime `'a`; the
    /// caller must ensure it is valid and not aliased for that whole lifetime.
    pub unsafe fn node_runtime_from_ptr<'a>(
        &self,
        ptr: *mut vlib_node_runtime_t,
    ) -> &'a mut NodeRuntimeRef<N> {
        let runtime = unsafe { NodeRuntimeRef::from_ptr_mut(ptr) };
        runtime
    }
}
// SAFETY: NOTE(review): the struct contains an `UnsafeCell` holding raw pointers, so these
// impls are a manual promise that cross-thread access is sound. Presumably the registration
// is only mutated by `register`/`unregister`, whose callers must rule out concurrent access —
// confirm.
unsafe impl<N: ProcessNode, const N_NEXT_NODES: usize> Send
for ProcessNodeRegistration<N, N_NEXT_NODES>
{
}
// SAFETY: see the note on the `Send` impl directly above; the same reasoning applies.
unsafe impl<N: ProcessNode, const N_NEXT_NODES: usize> Sync
for ProcessNodeRegistration<N, N_NEXT_NODES>
{
}
/// State shared between a process node's executor context and the futures it polls.
pub(crate) struct ProcessAsyncContextShared {
// Timer wheel; a clone of this handle is given to every `Timer` created by `Sleep`.
timer_wheel: Rc<RefCell<Box<TimerWheel>>>,
// Waker used to build the `Context` the node's future is polled with.
waker: Arc<ProcessAsyncContextWaker>,
// Origin of the tick clock: ticks are counted from this instant.
start_time: Instant,
}
impl ProcessAsyncContextShared {
    /// Builds the shared state for the process node identified by `node_index`.
    ///
    /// Allocates and initialises a timer wheel, creates the waker that signals the node, and
    /// records the current instant as tick zero.
    fn new(node_index: u32) -> Self {
        let mut uninit_wheel = Box::new_uninit();
        TimerWheel::init(&mut uninit_wheel);
        // SAFETY: NOTE(review): relies on `TimerWheel::init` having fully initialised the
        // allocation above — confirm against its definition.
        let wheel = unsafe { uninit_wheel.assume_init() };

        let timer_wheel = Rc::new(RefCell::new(wheel));
        let waker = Arc::new(ProcessAsyncContextWaker { node_index });
        let start_time = Instant::now();
        Self {
            timer_wheel,
            waker,
            start_time,
        }
    }

    /// Converts an instant into a tick count relative to `start_time`.
    ///
    /// Instants before `start_time` map to tick 0; counts that do not fit in a `u64`
    /// saturate to `u64::MAX`.
    fn instant_to_ticks(&self, t: Instant) -> u64 {
        let elapsed_ms = t.saturating_duration_since(self.start_time).as_millis();
        let ticks = elapsed_ms.div_ceil(u128::from(TICK_INTERVAL_PER_MS));
        u64::try_from(ticks).unwrap_or(u64::MAX)
    }
}
pin_project! {
/// Executor state for a single process node: the future to drive plus the shared
/// timer/waker state it runs against.
pub struct ProcessAsyncContext<'a> {
// Raw pointer to the main structure, passed to the C process-node loop in `run`.
main_ref: *mut vlib_main_t,
// The node's future; set to `None` once it has completed so it is never polled again.
#[pin]
future: Option<LocalFutureObj<'a, ()>>,
// State published to futures through the thread-local context while polling.
shared: Rc<ProcessAsyncContextShared>,
}
}
impl<'a> ProcessAsyncContext<'a> {
    /// Creates the executor context for `future`, which belongs to the process node `node`.
    ///
    /// The node's index is captured so that wake-ups can signal that node.
    pub fn new<N>(
        vm: &'a mut MainRef,
        node: &NodeRuntimeRef<N>,
        future: LocalFutureObj<'a, ()>,
    ) -> Self {
        let main_ref = vm.as_ptr();
        let node_index = node.node_index();
        let shared = Rc::new(ProcessAsyncContextShared::new(node_index));
        Self {
            main_ref,
            future: Some(future),
            shared,
        }
    }

    /// Hands control to the C process-node loop and never returns.
    ///
    /// The context lives in this function's frame for the whole loop; the C side receives a
    /// pointer to it as an opaque `async_context`.
    pub fn run(mut self) -> ! {
        let vm = self.main_ref;
        let opaque_ctx = &mut self as *mut Self as *mut async_context;
        unsafe { vlib_helper_process_node_loop(vm, opaque_ctx) }
    }
}
/// Waker payload: identifies the process node to signal when a wake-up is requested.
struct ProcessAsyncContextWaker {
// Index of the node that receives the signalled event.
node_index: u32,
}
impl ArcWake for ProcessAsyncContextWaker {
fn wake_by_ref(arc_self: &std::sync::Arc<Self>) {
let mut args = vlib_process_signal_event_mt_args_t {
node_index: arc_self.node_index as u64,
type_opaque: 0,
data: 0,
};
unsafe {
vl_api_force_rpc_call_main_thread(
vlib_process_signal_event_mt_helper as *mut c_void,
std::ptr::addr_of_mut!(args) as *mut u8,
std::mem::size_of_val(&args) as u32,
)
};
}
}
/// C-callable entry point that drives the process node's future by one poll.
///
/// Expires all timers that are due at the current tick, then, if the future has not yet
/// completed, polls it once with this context's shared state installed as the thread-local
/// async context. A completed future is replaced by `None` so it is never polled again.
///
/// # Safety
///
/// `context` is dereferenced as an exclusive reference and pinned without any checks, so it
/// must point to a valid `ProcessAsyncContext` that is not aliased and is not moved afterwards.
#[unsafe(no_mangle)]
unsafe extern "C" fn vpp_plugin_rs_poll_async_coroutine(context: *mut ProcessAsyncContext) {
let mut ctx = unsafe { Pin::new_unchecked(&mut *context) };
// Expire timers first, so timers that became due are processed before the future is polled.
let ticks_since_start = ctx.shared.instant_to_ticks(Instant::now());
ctx.shared
.timer_wheel
.borrow_mut()
.expire_timers(ticks_since_start);
let ctx_project = ctx.as_mut().project();
if let Some(fut) = ctx_project.future.as_pin_mut() {
// Publish the shared state so `with_current_async_context` works during the poll.
ASYNC_CONTEXT.with(|tls_ctx| {
tls_ctx.replace(Some(ctx_project.shared.clone()));
});
let waker = waker_ref(&ctx_project.shared.waker);
let mut executor_context = Context::from_waker(&waker);
if matches!(fut.poll(&mut executor_context), Poll::Ready(_)) {
// Finished: clear the slot so the completed future is not polled again.
ctx.project().future.set(None);
}
// Withdraw the thread-local context again now that polling is over.
ASYNC_CONTEXT.with(|tls_ctx| {
tls_ctx.replace(None);
});
}
}
/// C-callable query for the timer wheel's next expiration, converted from ticks to seconds.
///
/// Returns `f64::MAX` when the wheel reports no next expiration.
///
/// # Safety
///
/// `context` is dereferenced as a shared reference without any checks, so it must point to a
/// valid `ProcessAsyncContext`.
#[unsafe(no_mangle)]
unsafe extern "C" fn vpp_plugin_rs_next_timer_duration(context: *mut ProcessAsyncContext) -> f64 {
    let ctx = unsafe { &*context };
    // Bind the result first so the `RefCell` borrow ends before the conversion below.
    let pending_ticks = ctx.shared.timer_wheel.borrow().next_expiration();
    match pending_ticks {
        Some(ticks) => ticks as f64 * TICK_INTERVAL_S,
        None => f64::MAX,
    }
}
thread_local! {
/// Shared state of the process node currently being polled on this thread.
/// `Some` only for the duration of a poll in `vpp_plugin_rs_poll_async_coroutine`.
static ASYNC_CONTEXT: RefCell<Option<Rc<ProcessAsyncContextShared>>> = const { RefCell::new(None) };
}
/// Runs `f` with the shared state of the process node currently being polled on this thread
/// and returns its result.
///
/// # Panics
///
/// Panics if no async context is installed, i.e. when called outside of a process node poll.
pub(crate) fn with_current_async_context<F, R>(f: F) -> R
where
    F: FnOnce(&Rc<ProcessAsyncContextShared>) -> R,
{
    ASYNC_CONTEXT.with(|slot| {
        let current = slot.borrow();
        let shared = current.as_ref().expect(
            "There is no async context present - must be called from a vpp-plugin-rs process node",
        );
        f(shared)
    })
}
pin_project! {
/// Future returned by `sleep`; it completes when its underlying timer does.
#[project(!Unpin)]
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Sleep {
// Timer registered against the process node's timer wheel; polled through a pin.
#[pin]
entry: Timer,
}
}
impl Sleep {
pub(crate) fn new_timeout(deadline: Instant, ctx: &Rc<ProcessAsyncContextShared>) -> Self {
let deadline_ticks = ctx.instant_to_ticks(deadline);
let entry = Timer::new(ctx.timer_wheel.clone(), deadline_ticks);
Self { entry }
}
pub fn is_elapsed(&self) -> bool {
self.entry.is_ready()
}
}
impl Future for Sleep {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
self.project().entry.poll(cx)
}
}
/// Returns a future that completes after `duration` has elapsed.
///
/// If `now + duration` is not representable, the deadline falls back to roughly thirty years
/// from now.
///
/// # Panics
///
/// Panics if called outside of a process node poll, because no async context is installed.
pub fn sleep(duration: Duration) -> Sleep {
    let deadline = match Instant::now().checked_add(duration) {
        Some(instant) => instant,
        // Overflow: substitute a deadline far enough away to be effectively "never".
        None => Instant::now() + Duration::from_secs(86400 * 365 * 30),
    };
    with_current_async_context(|shared| Sleep::new_timeout(deadline, shared))
}
/// Error produced by `Timeout` when the deadline passes before the wrapped future completes.
///
/// The private unit field keeps construction inside this module.
#[derive(Debug, PartialEq, Eq)]
#[allow(missing_copy_implementations)]
pub struct Elapsed(());

impl fmt::Display for Elapsed {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` is what `str`'s own `Display` impl uses, so width/precision flags still apply.
        f.pad("deadline has elapsed")
    }
}

impl std::error::Error for Elapsed {}

impl From<Elapsed> for std::io::Error {
    /// Maps an elapsed deadline to an I/O error of kind `TimedOut`.
    fn from(_err: Elapsed) -> std::io::Error {
        std::io::Error::from(std::io::ErrorKind::TimedOut)
    }
}
/// Wraps `future` so that it resolves to `Err(Elapsed)` if it has not completed within
/// `duration`.
///
/// # Panics
///
/// Panics if called outside of a process node poll, because the delay is created with `sleep`.
pub fn timeout<F: IntoFuture>(duration: Duration, future: F) -> Timeout<F::IntoFuture> {
    // The delay is created first, so a missing async context panics before `future` is
    // converted.
    let deadline_timer = sleep(duration);
    let inner = future.into_future();
    Timeout::new_with_delay(inner, deadline_timer)
}
pin_project! {
/// Future returned by `timeout`: races a wrapped future against a `Sleep` delay.
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[derive(Debug)]
pub struct Timeout<T> {
// The wrapped future (or value) whose completion is awaited.
#[pin]
value: T,
// Delay whose completion turns the result into `Err(Elapsed)`.
#[pin]
delay: Sleep,
}
}
impl<T> Timeout<T> {
pub(crate) fn new_with_delay(value: T, delay: Sleep) -> Timeout<T> {
Timeout { value, delay }
}
pub fn get_ref(&self) -> &T {
&self.value
}
pub fn get_mut(&mut self) -> &mut T {
&mut self.value
}
pub fn into_inner(self) -> T {
self.value
}
}
impl<T: Future> Future for Timeout<T> {
    type Output = Result<T::Output, Elapsed>;

    /// Polls the wrapped future first, so a value that is ready wins even if the delay has
    /// also expired; only a pending value causes the delay to be consulted.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        match this.value.poll(cx) {
            Poll::Ready(output) => Poll::Ready(Ok(output)),
            Poll::Pending => this.delay.poll(cx).map(|()| Err(Elapsed(()))),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{Elapsed, sleep};
    use std::time::Duration;

    /// `sleep` needs the thread-local async context, which a plain test thread does not have.
    #[test]
    #[should_panic(
        expected = "There is no async context present - must be called from a vpp-plugin-rs process node"
    )]
    fn sleep_outside_process_node_panics() {
        let one_second = Duration::from_secs(1);
        drop(sleep(one_second));
    }

    /// Converting `Elapsed` into an I/O error yields the `TimedOut` kind.
    #[test]
    fn elapsed_to_std_error() {
        let converted = std::io::Error::from(Elapsed(()));
        assert_eq!(converted.kind(), std::io::ErrorKind::TimedOut);
    }
}