1use std::sync::{
26 Arc,
27 atomic::{AtomicBool, AtomicUsize, Ordering},
28};
29
30use futures_util::task::{ArcWake, waker_ref};
31use tokio::sync::Notify;
32use wasmtime::{Memory, Store};
33
34use selium_abi::{
35 GuestAtomicUint, GuestUint,
36 mailbox::{CAPACITY, FLAG_OFFSET, RING_OFFSET, TAIL_OFFSET},
37};
38
/// Host-side handle to a mailbox that lives inside guest linear memory.
/// Host wakers publish ready task ids into the guest-visible ring via
/// `enqueue`; the guest scheduler consumes them at the ABI-defined offsets
/// (`FLAG_OFFSET` / `TAIL_OFFSET` / `RING_OFFSET`).
pub struct GuestMailbox {
    // Base address of the guest's linear memory in host address space.
    // Stored atomically so it can be re-recorded after the guest memory
    // grows (and may relocate); see `refresh_base`.
    base: AtomicUsize,
    // Set by `close`; once true, `enqueue`/`is_signalled` become no-ops
    // so we never dereference a stale `base`.
    closed: AtomicBool,
    // Wakes host-side `wait_for_signal` callers on enqueue/close.
    notify: Notify,
}
49
// SAFETY: every field (`AtomicUsize`, `AtomicBool`, `Notify`) is itself
// `Send + Sync`, so these impls assert nothing beyond what the auto traits
// would derive. NOTE(review): they appear redundant for that same reason —
// confirm and consider deleting the `unsafe impl`s.
unsafe impl Send for GuestMailbox {}

unsafe impl Sync for GuestMailbox {}
53
54impl GuestMailbox {
55 unsafe fn new<T>(memory: &Memory, store: &mut Store<T>) -> Self {
60 let base = memory.data_ptr(store) as usize;
61 Self {
62 base: AtomicUsize::new(base),
63 closed: AtomicBool::new(false),
64 notify: Notify::new(),
65 }
66 }
67
68 pub(crate) fn refresh_base(&self, base: usize) {
70 self.base.store(base, Ordering::Release);
71 }
72
73 pub(crate) fn close(&self) {
75 self.closed.store(true, Ordering::Release);
76 self.notify.notify_one();
77 }
78
79 pub(crate) fn is_closed(&self) -> bool {
81 self.closed.load(Ordering::Acquire)
82 }
83
84 fn ptrs(
85 &self,
86 ) -> (
87 *const GuestAtomicUint,
88 *const GuestAtomicUint,
89 *const GuestAtomicUint,
90 ) {
91 let base = self.base.load(Ordering::Acquire);
92 (
93 (base + FLAG_OFFSET) as *const _,
94 (base + TAIL_OFFSET) as *const _,
95 (base + RING_OFFSET) as *const _,
96 )
97 }
98
99 fn enqueue(&self, task_id: usize) {
101 if self.closed.load(Ordering::Acquire) {
102 return;
103 }
104 unsafe {
105 let (flag, tail_ptr, ring) = self.ptrs();
106 let tail = (*tail_ptr).fetch_add(1, Ordering::AcqRel);
107 let slot = (tail % CAPACITY) as usize;
108 let id = GuestUint::try_from(task_id).expect("task id exceeds guest width");
109 (*ring.add(slot)).store(id, Ordering::Relaxed);
110 (*flag).store(1, Ordering::Release);
111 #[cfg(target_os = "linux")]
112 {
113 libc::syscall(
114 libc::SYS_futex,
115 flag as *const GuestAtomicUint as libc::c_long,
116 libc::FUTEX_WAKE as libc::c_long,
117 1 as libc::c_long,
118 );
119 }
120 }
121 self.notify.notify_one();
122 }
123
124 pub(crate) fn is_signalled(&self) -> bool {
126 if self.closed.load(Ordering::Acquire) {
127 return false;
128 }
129 let (flag, _tail, _ring) = self.ptrs();
130 unsafe { (*flag).load(Ordering::Acquire) != 0 }
131 }
132
133 pub(crate) async fn wait_for_signal(&self) {
135 self.notify.notified().await;
136 }
137
138 pub(crate) fn waker(&'static self, task_id: usize) -> std::task::Waker {
140 struct MbWaker {
141 mb: &'static GuestMailbox,
142 id: usize,
143 }
144 impl ArcWake for MbWaker {
145 fn wake_by_ref(arc_self: &Arc<Self>) {
146 arc_self.mb.enqueue(arc_self.id);
147 }
148 }
149 let arc = Arc::new(MbWaker {
150 mb: self,
151 id: task_id,
152 });
153 waker_ref(&arc).clone()
154 }
155}
156
157pub unsafe fn create_guest_mailbox<T>(
160 memory: &Memory,
161 store: &mut Store<T>,
162) -> &'static GuestMailbox {
163 Box::leak(Box::new(unsafe { GuestMailbox::new(memory, store) }))
164}
165
#[cfg(test)]
mod tests {
    use selium_abi::mailbox::SLOT_SIZE;
    use wasmtime::{Engine, MemoryType};

    use super::*;

    /// Enqueueing a single task id must advance the tail to 1, write the
    /// id into the first ring slot, and raise the signal flag.
    #[test]
    fn enqueue_writes_ring_and_sets_flag() {
        let engine = Engine::default();
        let mut store = Store::new(&engine, ());
        let memory = Memory::new(&mut store, MemoryType::new(1, None)).expect("memory");

        // Zero the entire mailbox region (flag, tail, and all ring slots)
        // so the assertions start from a known state.
        let mailbox_len = RING_OFFSET + (CAPACITY as usize * SLOT_SIZE);
        memory.data_mut(&mut store)[..mailbox_len].fill(0);

        let mailbox = unsafe { GuestMailbox::new(&memory, &mut store) };
        mailbox.enqueue(7);

        let base = memory.data_ptr(&mut store) as usize;
        let flag_ptr = (base + FLAG_OFFSET) as *const GuestAtomicUint;
        let tail_ptr = (base + TAIL_OFFSET) as *const GuestAtomicUint;
        let ring_ptr = (base + RING_OFFSET) as *const GuestAtomicUint;

        // SAFETY: the pointers above address the freshly created wasmtime
        // memory, which outlives these loads.
        unsafe {
            assert_eq!((*tail_ptr).load(Ordering::Relaxed) as usize, 1);
            assert_eq!((*ring_ptr).load(Ordering::Relaxed), 7);
            assert_eq!((*flag_ptr).load(Ordering::Relaxed), 1);
        }
    }
}