1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
use std::{
    sync::{
        Arc,
        atomic::{AtomicBool, AtomicPtr, AtomicU32, AtomicU64, Ordering},
    },
    time::Duration,
};

use dashmap::DashMap;
use fnv::FnvBuildHasher;
use parking_lot::{Condvar, Mutex};
use thiserror::Error;
/// Error that can occur during wait/notify calls.
// Non-exhaustive to allow for future variants without breaking changes!
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum WaiterError {
    /// Wait/Notify is not implemented for this memory
    Unimplemented,
    /// Too many waiters are registered for an address
    TooManyWaiters,
    /// Atomic operations are disabled.
    AtomicsDisabled,
}
impl std::fmt::Display for WaiterError {
    /// Render a human-readable, per-variant description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Previously every variant printed the generic "WaiterError", which
        // discarded the information callers need for diagnostics.
        let msg = match self {
            Self::Unimplemented => "wait/notify is not implemented for this memory",
            Self::TooManyWaiters => "too many waiters for address",
            Self::AtomicsDisabled => "atomic operations are disabled",
        };
        f.write_str(msg)
    }
}
/// Expected value for atomic waits.
///
/// Carries the value a waiter expects to find at the wait address; the wait
/// only sleeps when the value in memory still matches.
// Derives added: this is a small plain-data enum, and `Copy`/`Clone`/`Debug`/
// equality make it usable in logs, tests and by-value APIs, consistent with
// `NotifyLocation` which already derives `Clone, Copy, Debug`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExpectedValue {
    /// No expected value; this is used for native waits only.
    None,
    /// 32-bit expected value
    U32(u32),
    /// 64-bit expected value
    U64(u64),
}
/// A location in memory for a Waiter
#[derive(Clone, Copy, Debug)]
pub struct NotifyLocation {
    /// The address of the Waiter location (a byte offset relative to `memory_base`)
    pub address: u32,
    /// The base of the memory this address is relative to
    // NOTE(review): raw pointer — validity/alignment are the caller's
    // responsibility, see the safety contract on `do_wait`.
    pub memory_base: *mut u8,
}
// Shared state behind `ThreadConditions`: a per-address registry of sleeping
// waiters plus a global "atomics disabled" flag.
#[derive(Debug, Default)]
struct NotifyMap {
    /// If set to true, all waits will fail with an error.
    closed: AtomicBool,
    // For each wait address, we store a mutex and a condvar. The condvar is
    // used to handle sleeping and waking, while the mutex stores the
    // (manually-updated) number of waiters on that address. This lets us
    // know when there are no more waiters so we can clean up the map entry.
    // Note that using a Weak here would be insufficient since it can't
    // clean up the map entries for us, only the mutexes/condvars.
    map: DashMap<u32, Arc<(Mutex<u32>, Condvar)>, FnvBuildHasher>,
}
/// HashMap of Waiters for the Thread/Notify opcodes
#[derive(Debug)]
pub struct ThreadConditions {
    // The shared HashMap with the condvars for the Notify/wait opcodes;
    // clones of `ThreadConditions` share this same map via the Arc.
    inner: Arc<NotifyMap>,
}
impl Clone for ThreadConditions {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
impl ThreadConditions {
/// Create a new ThreadConditions
pub fn new() -> Self {
Self {
inner: Arc::new(NotifyMap::default()),
}
}
// To implement Wait / Notify, a HasMap, behind a mutex, will be used
// to track the address of waiter. The key of the hashmap is based on the memory.
// The actual waiting is implemented with a Condvar + Mutex pair. A Weak is stored
// in the hashmap to at least delete the condvar and mutex when there are no
// waiters for a given address. Map keys are currently not cleaned up.
/// Add current thread to the waiter hash
///
/// # Safety
/// If `expected` is [`ExpectedValue::None`], no safety requirements.
/// The notify location must have a valid base address that belongs to a memory,
/// and the address must be a valid offset within that memory. The offset also
/// must be properly aligned for the expected value type; either 4-byte aligned for
/// [`ExpectedValue::U32`] or 8-byte aligned for [`ExpectedValue::U64`].
pub unsafe fn do_wait(
&mut self,
dst: NotifyLocation,
expected: ExpectedValue,
timeout: Option<Duration>,
) -> Result<u32, WaiterError> {
if self.inner.closed.load(std::sync::atomic::Ordering::Acquire) {
return Err(WaiterError::AtomicsDisabled);
}
if self.inner.map.len() as u64 >= 1u64 << 32 {
return Err(WaiterError::TooManyWaiters);
}
// Step 1: lock the map key, so we know no one else can get/create a
// different Arc than the one we're getting/creating
let entry = self.inner.map.entry(dst.address);
let ref_mut = entry.or_default();
let arc = ref_mut.clone();
// Step 2: lock the mutex while still holding the map lock, so nobody
// can delete the map key or make a new Arc
let mut mutex_guard = arc.0.lock();
// Step 3: unlock the map key, we don't need it anymore.
drop(ref_mut);
// Once we lock the mutex, we can check the expected value. A notifying
// thread will have written an updated value to the address *before*
// doing the notify call, and the call has to acquire the same lock we're
// holding. This means we can't miss an update to the expected value that
// would prevent us from sleeping.
// This logic mirrors how the linux kernel's futex syscall works, so see
// the documentation on that if I made zero sense here.
// Safety: the function's safety contract ensures that the memory location is valid
// and can be dereferenced.
let should_sleep = match expected {
ExpectedValue::None => true,
ExpectedValue::U32(expected_val) => unsafe {
let src = dst.memory_base.offset(dst.address as isize) as *mut u32;
let atomic_src = AtomicPtr::new(src);
let read_val = *atomic_src.load(Ordering::Acquire);
read_val == expected_val
},
ExpectedValue::U64(expected_val) => unsafe {
let src = dst.memory_base.offset(dst.address as isize) as *mut u64;
let atomic_src = AtomicPtr::new(src);
let read_val = *atomic_src.load(Ordering::Acquire);
read_val == expected_val
},
};
let ret = if should_sleep {
*mutex_guard += 1;
let ret = if let Some(timeout) = timeout {
let timeout = arc.1.wait_for(&mut mutex_guard, timeout);
if timeout.timed_out() {
2 // timeout
} else {
0 // notified
}
} else {
arc.1.wait(&mut mutex_guard);
0
};
*mutex_guard -= 1;
ret
} else {
1 // value mismatch
};
{
// Note we use two sets of locks; one for the map itself, and one per
// wait address. Locking order must stay consistent at all times: map
// first, then mutex. So we have to drop the mutex guard here and then
// reacquire it after locking the map key to avoid deadlocks.
drop(mutex_guard);
// Same as above, first lock the map key...
let entry = self.inner.map.entry(dst.address);
if let dashmap::Entry::Occupied(occupied) = entry {
// ... then lock the mutex.
let arc = occupied.get().clone();
let mutex_guard = arc.0.lock();
if *mutex_guard == 0 {
// No more waiters, remove the map entry.
occupied.remove();
}
}
}
Ok(ret)
}
/// Notify waiters from the wait list
pub fn do_notify(&mut self, dst: u32, count: u32) -> u32 {
let mut count_token = 0u32;
if let Some(v) = self.inner.map.get(&dst) {
let mutex_guard = v.0.lock();
for _ in 0..count {
if !v.1.notify_one() {
break;
}
count_token += 1;
}
drop(mutex_guard);
}
count_token
}
/// Wake all the waiters, *without* marking them as notified.
///
/// Useful on shutdown to resume execution in all waiters.
pub fn wake_all_atomic_waiters(&self) {
for item in self.inner.map.iter_mut() {
let arc = item.value();
let _mutex_guard = arc.0.lock();
arc.1.notify_all();
}
}
/// Disable the use of atomics, leading to all atomic waits failing with
/// an error, which leads to a Webassembly trap.
///
/// Useful for force-closing instances that keep waiting on atomics.
pub fn disable_atomics(&self) {
self.inner
.closed
.store(true, std::sync::atomic::Ordering::Release);
self.wake_all_atomic_waiters();
}
/// Get a weak handle to this `ThreadConditions` instance.
///
/// See [`ThreadConditionsHandle`] for more information.
pub fn downgrade(&self) -> ThreadConditionsHandle {
ThreadConditionsHandle {
inner: Arc::downgrade(&self.inner),
}
}
}
/// A weak handle to a `ThreadConditions` instance, which does not prolong its
/// lifetime.
///
/// Internally holds a [`std::sync::Weak`] pointer.
pub struct ThreadConditionsHandle {
    // Weak reference: holding a handle never keeps the waiter map alive.
    inner: std::sync::Weak<NotifyMap>,
}
impl ThreadConditionsHandle {
    /// Attempt to upgrade this handle to a strong reference.
    ///
    /// Returns `None` if the original `ThreadConditions` instance has been dropped.
    pub fn upgrade(&self) -> Option<ThreadConditions> {
        let inner = self.inner.upgrade()?;
        Some(ThreadConditions { inner })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn threadconditions_notify_nowaiters() {
        // Notifying an address nobody waits on wakes zero threads.
        let mut conditions = ThreadConditions::new();
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 0);
    }

    #[test]
    fn threadconditions_notify_1waiter() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        let waiter = thread::spawn(move || {
            let dst = NotifyLocation {
                address: 0,
                memory_base: std::ptr::null_mut(),
            };
            // `ExpectedValue::None` never dereferences `memory_base`, so the
            // null pointer is fine here.
            let ret = unsafe { threadcond.do_wait(dst, ExpectedValue::None, None) }.unwrap();
            assert_eq!(ret, 0);
        });
        // Give the waiter time to register before notifying.
        thread::sleep(Duration::from_millis(10));
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 1);
        // Join so an assertion failure inside the waiter fails the test
        // instead of being silently dropped with a detached thread.
        waiter.join().unwrap();
    }

    #[test]
    fn threadconditions_notify_waiter_timeout() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        let waiter = thread::spawn(move || {
            let dst = NotifyLocation {
                address: 0,
                memory_base: std::ptr::null_mut(),
            };
            let ret = unsafe {
                threadcond
                    .do_wait(dst, ExpectedValue::None, Some(Duration::from_millis(1)))
                    .unwrap()
            };
            assert_eq!(ret, 2);
        });
        // Wait well past the 1ms timeout; by then nobody is left to notify.
        thread::sleep(Duration::from_millis(50));
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 0);
        waiter.join().unwrap();
    }

    #[test]
    fn threadconditions_notify_waiter_mismatch() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        let waiter = thread::spawn(move || {
            // Waits on address 8 while the main thread notifies address 0,
            // so this waiter must time out (return code 2).
            let dst = NotifyLocation {
                address: 8,
                memory_base: std::ptr::null_mut(),
            };
            let ret = unsafe {
                threadcond
                    .do_wait(dst, ExpectedValue::None, Some(Duration::from_millis(10)))
                    .unwrap()
            };
            assert_eq!(ret, 2);
        });
        thread::sleep(Duration::from_millis(1));
        let ret = conditions.do_notify(0, 1);
        assert_eq!(ret, 0);
        // Joining replaces the old fixed 100ms sleep: it both waits exactly as
        // long as needed and propagates the waiter's assertions.
        waiter.join().unwrap();
    }

    #[test]
    fn threadconditions_notify_2waiters() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();
        let mut threadcond2 = conditions.clone();

        let waiter1 = thread::spawn(move || {
            let dst = NotifyLocation {
                address: 0,
                memory_base: std::ptr::null_mut(),
            };
            let ret = unsafe { threadcond.do_wait(dst, ExpectedValue::None, None).unwrap() };
            assert_eq!(ret, 0);
        });
        let waiter2 = thread::spawn(move || {
            let dst = NotifyLocation {
                address: 0,
                memory_base: std::ptr::null_mut(),
            };
            let ret = unsafe { threadcond2.do_wait(dst, ExpectedValue::None, None).unwrap() };
            assert_eq!(ret, 0);
        });
        thread::sleep(Duration::from_millis(20));
        // Ask for up to 5 wakeups; only the 2 registered waiters are woken.
        let ret = conditions.do_notify(0, 5);
        assert_eq!(ret, 2);
        waiter1.join().unwrap();
        waiter2.join().unwrap();
    }

    #[test]
    fn threadconditions_value_mismatch() {
        let mut conditions = ThreadConditions::new();
        // Memory holds 42 but we expect 85, so `do_wait` must return 1
        // (value mismatch) immediately, without sleeping.
        let mut data: u32 = 42;
        let dst = NotifyLocation {
            address: 0,
            memory_base: (&mut data as *mut u32) as *mut u8,
        };
        let ret = unsafe {
            conditions
                .do_wait(dst, ExpectedValue::U32(85), Some(Duration::from_millis(10)))
                .unwrap()
        };
        assert_eq!(ret, 1);
    }
}