use std::sync::{
Arc,
atomic::{AtomicBool, AtomicUsize, Ordering},
};
use futures_util::task::{ArcWake, waker_ref};
use tokio::sync::Notify;
use wasmtime::{Memory, Store};
use selium_abi::{
GuestAtomicUint, GuestUint,
mailbox::{CAPACITY, FLAG_OFFSET, RING_OFFSET, TAIL_OFFSET},
};
/// Host-side handle to a wake-up mailbox that lives inside a guest's linear memory.
///
/// Wakers created by `GuestMailbox::waker` push task ids into a ring in guest memory and
/// raise a flag word; host tasks can await `GuestMailbox::wait_for_signal`.
pub struct GuestMailbox {
    /// Host address of the start of guest linear memory; the flag, tail and ring offsets are
    /// resolved against it on every access.
    base: AtomicUsize,
    /// Set by `close`; once set, enqueues are dropped and `is_signalled` reports `false`.
    closed: AtomicBool,
    /// Wakes host tasks waiting for a new entry or for the mailbox to close.
    notify: Notify,
}

// Every field is already `Send + Sync`, so no `unsafe impl` is needed. This compile-time
// check keeps it that way: adding a non-thread-safe field fails to build instead of being
// silently vouched for by a hand-written `unsafe impl`.
const _: fn() = || {
    fn assert_send_sync<T: Send + Sync>() {}
    assert_send_sync::<GuestMailbox>();
};
impl GuestMailbox {
    /// Creates a mailbox rooted at the current host address of `memory`.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that:
    /// - `memory` contains the flag at `FLAG_OFFSET`, the tail at `TAIL_OFFSET`, and
    ///   `CAPACITY` ring slots starting at `RING_OFFSET`, all in bounds;
    /// - those locations are suitably aligned for `GuestAtomicUint`;
    /// - `refresh_base` is called whenever the memory's host address changes (for example,
    ///   if growing the memory relocates it), because every access goes through the stored
    ///   base.
    unsafe fn new<T>(memory: &Memory, store: &mut Store<T>) -> Self {
        let base = memory.data_ptr(store) as usize;
        Self {
            base: AtomicUsize::new(base),
            closed: AtomicBool::new(false),
            notify: Notify::new(),
        }
    }

    /// Replaces the host base address used to locate the flag, tail and ring.
    pub(crate) fn refresh_base(&self, base: usize) {
        self.base.store(base, Ordering::Release);
    }

    /// Marks the mailbox closed and wakes one waiter in `wait_for_signal`.
    ///
    /// If nobody is waiting, `Notify::notify_one` stores a permit, so the next
    /// `wait_for_signal` call returns immediately.
    pub(crate) fn close(&self) {
        self.closed.store(true, Ordering::Release);
        self.notify.notify_one();
    }

    /// Returns `true` once `close` has been called.
    pub(crate) fn is_closed(&self) -> bool {
        self.closed.load(Ordering::Acquire)
    }

    /// Resolves the current host pointers to the flag word, the tail counter and ring slot 0.
    fn ptrs(
        &self,
    ) -> (
        *const GuestAtomicUint,
        *const GuestAtomicUint,
        *const GuestAtomicUint,
    ) {
        let base = self.base.load(Ordering::Acquire);
        (
            (base + FLAG_OFFSET) as *const _,
            (base + TAIL_OFFSET) as *const _,
            (base + RING_OFFSET) as *const _,
        )
    }

    /// Appends `task_id` to the guest-visible ring, raises the flag and wakes waiters.
    ///
    /// Does nothing once the mailbox is closed.
    ///
    /// # Panics
    ///
    /// Panics if `task_id` does not fit in `GuestUint`. The check runs before a ring slot is
    /// reserved, so a panic never leaves a claimed-but-unwritten slot behind.
    fn enqueue(&self, task_id: usize) {
        if self.closed.load(Ordering::Acquire) {
            return;
        }
        // Convert first: reserving a slot and then panicking would advance the tail past a
        // slot that never receives an id, and the guest would read a stale entry.
        let id = GuestUint::try_from(task_id).expect("task id exceeds guest width");
        let (flag, tail_ptr, ring) = self.ptrs();
        // SAFETY: the constructor's contract guarantees that the flag, tail and ring are valid,
        // aligned `GuestAtomicUint`s inside guest memory. `slot < CAPACITY` keeps the ring
        // access within the ring.
        unsafe {
            let tail = (*tail_ptr).fetch_add(1, Ordering::AcqRel);
            let slot = (tail % CAPACITY) as usize;
            (*ring.add(slot)).store(id, Ordering::Relaxed);
            // Release pairs with an Acquire load of the flag, so a reader that observes the
            // flag also observes the slot write above.
            (*flag).store(1, Ordering::Release);
            #[cfg(target_os = "linux")]
            {
                // Wake at most one thread futex-waiting on the flag word.
                libc::syscall(
                    libc::SYS_futex,
                    flag as libc::c_long,
                    libc::FUTEX_WAKE as libc::c_long,
                    1 as libc::c_long,
                );
            }
        }
        self.notify.notify_one();
    }

    /// Returns `true` if the mailbox is open and the guest flag word is non-zero.
    pub(crate) fn is_signalled(&self) -> bool {
        if self.closed.load(Ordering::Acquire) {
            return false;
        }
        let (flag, _tail, _ring) = self.ptrs();
        // SAFETY: the constructor's contract guarantees the flag is a valid, aligned
        // `GuestAtomicUint` in guest memory.
        unsafe { (*flag).load(Ordering::Acquire) != 0 }
    }

    /// Waits until `enqueue` or `close` notifies this mailbox.
    ///
    /// Returns immediately if a `notify_one` permit was stored while nobody was waiting.
    pub(crate) async fn wait_for_signal(&self) {
        self.notify.notified().await;
    }

    /// Builds a `Waker` that, when woken, enqueues `task_id` into this mailbox.
    pub(crate) fn waker(&'static self, task_id: usize) -> std::task::Waker {
        /// Waker payload: the owning mailbox plus the task id to push on wake.
        struct MbWaker {
            mb: &'static GuestMailbox,
            id: usize,
        }

        impl ArcWake for MbWaker {
            fn wake_by_ref(arc_self: &Arc<Self>) {
                arc_self.mb.enqueue(arc_self.id);
            }
        }

        let arc = Arc::new(MbWaker {
            mb: self,
            id: task_id,
        });
        waker_ref(&arc).clone()
    }
}
/// Creates a `GuestMailbox` over `memory` and leaks it, so wakers can hold a `'static`
/// reference to it.
///
/// Each call leaks one allocation that is never freed.
///
/// # Safety
///
/// `memory` must contain in-bounds, suitably aligned `GuestAtomicUint` locations for the
/// flag (`FLAG_OFFSET`), the tail (`TAIL_OFFSET`) and `CAPACITY` ring slots starting at
/// `RING_OFFSET`. The mailbox's base must be refreshed whenever the memory's host address
/// changes, because every access goes through the stored base.
pub unsafe fn create_guest_mailbox<T>(
    memory: &Memory,
    store: &mut Store<T>,
) -> &'static GuestMailbox {
    // SAFETY: forwarded unchanged from this function's own `# Safety` contract.
    let mailbox = unsafe { GuestMailbox::new(memory, store) };
    let boxed = Box::new(mailbox);
    Box::leak(boxed)
}
#[cfg(test)]
mod tests {
    use selium_abi::mailbox::SLOT_SIZE;
    use wasmtime::{Engine, MemoryType};

    use super::*;

    /// Creates a one-page guest memory whose mailbox region (flag, tail and ring) is zeroed.
    fn zeroed_guest_memory() -> (Store<()>, Memory) {
        let engine = Engine::default();
        let mut store = Store::new(&engine, ());
        let memory = Memory::new(&mut store, MemoryType::new(1, None)).expect("memory");
        let mailbox_len = RING_OFFSET + (CAPACITY as usize * SLOT_SIZE);
        memory
            .data_mut(&mut store)
            .iter_mut()
            .take(mailbox_len)
            .for_each(|byte| *byte = 0);
        (store, memory)
    }

    /// Reads the guest word located `offset` bytes past the host address `base`.
    fn load_word(base: usize, offset: usize) -> GuestUint {
        // SAFETY: callers only pass offsets inside the mailbox region of a live, ungrown memory.
        unsafe { (*((base + offset) as *const GuestAtomicUint)).load(Ordering::Relaxed) }
    }

    #[test]
    fn enqueue_writes_ring_and_sets_flag() {
        let (mut store, memory) = zeroed_guest_memory();
        let mailbox = unsafe { GuestMailbox::new(&memory, &mut store) };
        mailbox.enqueue(7);
        let base = memory.data_ptr(&mut store) as usize;
        assert_eq!(load_word(base, TAIL_OFFSET), 1);
        assert_eq!(load_word(base, RING_OFFSET), 7);
        assert_eq!(load_word(base, FLAG_OFFSET), 1);
        assert!(mailbox.is_signalled());
    }

    #[test]
    fn enqueue_after_close_is_a_no_op() {
        let (mut store, memory) = zeroed_guest_memory();
        let mailbox = unsafe { GuestMailbox::new(&memory, &mut store) };
        mailbox.close();
        assert!(mailbox.is_closed());
        mailbox.enqueue(7);
        let base = memory.data_ptr(&mut store) as usize;
        assert_eq!(load_word(base, TAIL_OFFSET), 0);
        assert_eq!(load_word(base, FLAG_OFFSET), 0);
        assert!(!mailbox.is_signalled());
    }
}