use crate::runtime::vm::VMStore;
use core::cell::Cell;
use core::mem;
use core::ptr::NonNull;
// Per-thread pointer to the `SetStorage` installed by the innermost `set` call
// that is still running on this thread, or `None` when no `set` call is active.
// The pointee is a local in that `set` call's stack frame; `set` puts the
// previous value back before that local goes out of scope.
std::thread_local! {
static STORAGE: Cell<Option<NonNull<SetStorage>>> = const { Cell::new(None) };
}
/// Contents of the slot that `set` points `STORAGE` at.
///
/// `try_get` swaps this to `Taken` while it lends the store out and writes
/// `Present` back afterwards, which is how reentrant lookups are detected.
enum SetStorage {
/// The store pointer handed to `set`; available to the next lookup.
Present(NonNull<dyn VMStore>),
/// An enclosing `get`/`try_get` on this thread is currently holding the store.
Taken,
}
/// Makes `store` this thread's current store for the duration of `f`.
///
/// While `f` runs, [`get`] and [`try_get`] on this thread can reach `store`.
/// Whatever was installed before (possibly nothing) is reinstated when this
/// function returns or unwinds, so calls to `set` nest.
pub fn set<R>(store: &mut dyn VMStore, f: impl FnOnce() -> R) -> R {
    /// Restores the previously installed TLS pointer when dropped, including
    /// during unwinding out of `f`.
    struct ResetTls(Option<NonNull<SetStorage>>);

    impl Drop for ResetTls {
        fn drop(&mut self) {
            let previous = self.0;
            STORAGE.with(|slot| slot.set(previous));
        }
    }

    let mut slot_contents = SetStorage::Present(NonNull::from(store));
    let installed = NonNull::from(&mut slot_contents);
    let previous = STORAGE.with(|slot| slot.replace(Some(installed)));

    // Declared after `slot_contents`, so it is dropped first: the TLS pointer
    // is reset before the storage it points at goes out of scope.
    let _restore = ResetTls(previous);
    f()
}
/// Runs `f` with exclusive access to this thread's current store.
///
/// # Panics
///
/// Panics if no [`set`] call is active on this thread, or if the store is
/// already being held by an enclosing `get`/[`try_get`] call (reentrant use).
pub fn get<R>(f: impl FnOnce(&mut dyn VMStore) -> R) -> R {
    try_get(move |lookup| {
        let store = match lookup {
            TryGet::Some(store) => store,
            TryGet::Taken => get_failed(true),
            TryGet::None => get_failed(false),
        };
        f(store)
    })
}
/// Diverging failure path for [`get`]; panics explaining why no store was
/// available.
///
/// `taken` is `true` when a store is installed but currently held by an
/// enclosing lookup, and `false` when no store was installed at all.
#[cold]
fn get_failed(taken: bool) -> ! {
    if !taken {
        panic!(
            "`Accessor::with` was called when the TLS pointer was not \
             previously set; this is likely a bug in Wasmtime and we would \
             appreciate an issue being filed to help fix this."
        );
    }
    panic!(
        "attempted to recursively call `Accessor::with` when the pointer \
         was already taken by a previous call to `Accessor::with`; try \
         using `RUST_BACKTRACE=1` to find two stack frames to \
         `Accessor::with` on the stack"
    )
}
/// Outcome of looking up this thread's current store with [`try_get`].
pub enum TryGet<'a> {
/// No [`set`] call is active on this thread.
None,
/// A store is installed, but an enclosing [`get`]/[`try_get`] call is already
/// holding it, so it cannot be lent out a second time.
Taken,
/// The store installed by the innermost active [`set`], borrowed exclusively
/// for the duration of the `try_get` callback.
Some(&'a mut dyn VMStore),
}
/// Looks up this thread's current store and passes the outcome to `f`.
///
/// * [`TryGet::None`] if no [`set`] call is active on this thread.
/// * [`TryGet::Taken`] if a store is installed but an enclosing `get`/`try_get`
///   call is already holding it.
/// * [`TryGet::Some`] otherwise. While `f` runs, the slot is marked as taken,
///   so nested lookups observe [`TryGet::Taken`]. The slot is marked present
///   again when `f` returns or unwinds.
pub fn try_get<R>(f: impl FnOnce(TryGet<'_>) -> R) -> R {
    /// Writes the store pointer back into its slot when dropped, including
    /// when `f` unwinds.
    struct ResetStorage(NonNull<SetStorage>, NonNull<dyn VMStore>);

    impl Drop for ResetStorage {
        fn drop(&mut self) {
            // SAFETY: `self.0` is the slot pointer `try_get` read from
            // `STORAGE`. The slot is a local in an active `set` frame further
            // up this thread's stack. That frame cannot return before this
            // guard, which lives in a deeper frame, is dropped.
            unsafe {
                *self.0.as_mut() = SetStorage::Present(self.1);
            }
        }
    }

    let mut slot = match STORAGE.with(|s| s.get()) {
        Some(slot) => slot,
        None => return f(TryGet::None),
    };

    // SAFETY: `STORAGE` is private to this module, and only `set` stores a
    // `Some` in it. That pointer targets a `SetStorage` local in `set`'s frame,
    // and `set` restores the previous value before that local is dropped, so
    // the pointer is valid here. The TLS is per-thread, and no other reference
    // to the slot is live: only raw `NonNull`s are kept by `set` and by the
    // guards of enclosing `try_get` calls.
    let previous = unsafe { mem::replace(slot.as_mut(), SetStorage::Taken) };

    match previous {
        // Already lent out by an enclosing call; the slot stays `Taken`.
        SetStorage::Taken => f(TryGet::Taken),
        SetStorage::Present(mut ptr) => {
            // Armed before the store is lent out so the slot is restored even
            // if `f` panics.
            let _reset = ResetStorage(slot, ptr);
            // SAFETY: `ptr` came from the `&mut dyn VMStore` passed to `set`,
            // which remains exclusively borrowed for the entire `set` call.
            // Swapping the slot to `Taken` above ensures no other `try_get` on
            // this thread can hand out a second `&mut` to the same store until
            // `_reset` runs after `f` finishes.
            let store = unsafe { ptr.as_mut() };
            f(TryGet::Some(store))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{TryGet, get, set, try_get};
    use crate::{AsContextMut, Engine, Store};

    /// Builds a fresh store with no host data for a single test.
    fn fresh_store() -> Store<()> {
        let engine = Engine::default();
        Store::new(&engine, ())
    }

    /// Inside `set`, both `get` and `try_get` can reach the store.
    #[test]
    fn test_simple() {
        let mut store = fresh_store();
        set(store.as_context_mut().0, || {
            get(|_| {});
            try_get(|lookup| assert!(matches!(lookup, TryGet::Some(_))));
        });
    }

    /// Exercises every `TryGet` state: `None` outside `set`, `Some` for the
    /// outermost lookup inside `set`, and `Taken` for every nested lookup.
    #[test]
    fn test_try_get() {
        let mut store = fresh_store();

        try_get(|outer| {
            assert!(matches!(outer, TryGet::None));
            try_get(|inner| assert!(matches!(inner, TryGet::None)));
        });

        set(store.as_context_mut().0, || {
            get(|_| {
                try_get(|outer| {
                    assert!(matches!(outer, TryGet::Taken));
                    try_get(|inner| assert!(matches!(inner, TryGet::Taken)));
                });
            });

            try_get(|first| {
                assert!(matches!(first, TryGet::Some(_)));
                try_get(|second| {
                    assert!(matches!(second, TryGet::Taken));
                    try_get(|third| assert!(matches!(third, TryGet::Taken)));
                });
            });

            try_get(|first| {
                assert!(matches!(first, TryGet::Some(_)));
                try_get(|second| assert!(matches!(second, TryGet::Taken)));
            });
        });

        try_get(|after| assert!(matches!(after, TryGet::None)));
    }

    /// A `get` nested inside another `get` must panic with the "recursive"
    /// message instead of handing out a second `&mut` to the store.
    #[test]
    #[should_panic(expected = "attempted to recursively call")]
    fn test_get_panic() {
        let mut store = fresh_store();
        set(store.as_context_mut().0, || {
            get(|_| {
                get(|_| {
                    panic!("should not get here");
                });
            });
        });
    }
}