use std::{
cell::RefCell,
marker::PhantomData,
ops::{Deref, DerefMut},
thread::LocalKey,
};
/// Rollback guard that lives for the duration of `Cached::take`.
///
/// While `armed`, dropping it clears the slot's "held" flag and moves whatever
/// is left in `cached` back into the thread-local slot.
struct TakeCleanup<T: 'static> {
    /// When `false`, dropping the guard leaves the slot untouched.
    armed: bool,
    /// Value handed back to the slot by an armed drop, if any.
    cached: Option<T>,
    /// Thread-local slot storing a `(held, cached value)` pair.
    cache: &'static LocalKey<RefCell<(bool, Option<T>)>>,
}
impl<T: 'static> TakeCleanup<T> {
    /// Cancels the rollback: once disarmed, dropping the guard does nothing.
    const fn disarm(&mut self) {
        let Self { armed, .. } = self;
        *armed = false;
    }
}
impl<T: 'static> Drop for TakeCleanup<T> {
    /// Rolls the slot back if the guard is still armed.
    ///
    /// An armed drop clears the slot's held flag and stores whatever is left in
    /// `cached` (possibly `None`) as the slot's value. A disarmed drop does not
    /// touch the slot at all.
    fn drop(&mut self) {
        if self.armed {
            self.cache.with(|cell| {
                let mut state = cell.borrow_mut();
                debug_assert!(state.0, "cache expected to be held");
                let (held, stored) = &mut *state;
                *held = false;
                *stored = self.cached.take();
            });
        }
    }
}
/// Guard giving access to a value checked out of a thread-local cache slot.
///
/// Created by `Cached::take`. The guard dereferences to the value; dropping it
/// moves the value back into the slot and clears the slot's held flag.
#[must_use = "dropping the guard immediately hands the value back to the cache"]
pub struct Cached<T: 'static> {
    /// The checked-out value; `Some` until `Drop` moves it back into the slot.
    value: Option<T>,
    /// Thread-local slot storing a `(held, cached value)` pair.
    cache: &'static LocalKey<RefCell<(bool, Option<T>)>>,
    /// Raw-pointer marker: keeps the guard `!Send`/`!Sync`, so it cannot leave
    /// the thread whose slot it refers to.
    _not_send: PhantomData<*const ()>,
}
impl<T: 'static> Cached<T> {
    /// Checks a value out of `cache`, building one if the slot is empty.
    ///
    /// The slot is flagged as held, then:
    /// * if it contained a value, `reset` is run on that value and it is reused;
    /// * otherwise `create` produces a new value.
    ///
    /// The returned guard gives the value back to the slot when dropped.
    ///
    /// # Errors
    ///
    /// Returns the error from `reset` or `create`. In both cases the held flag
    /// is cleared again; after a failed `reset` the existing value is put back
    /// into the slot.
    ///
    /// # Panics
    ///
    /// Panics if the slot is already held on this thread. If `create` or `reset`
    /// panics, the held flag is cleared while unwinding.
    pub fn take<E>(
        cache: &'static LocalKey<RefCell<(bool, Option<T>)>>,
        create: impl FnOnce() -> Result<T, E>,
        reset: impl FnOnce(&mut T) -> Result<(), E>,
    ) -> Result<Self, E> {
        // Claim the slot and pull out whatever it currently stores.
        let previous = cache.with(|cell| {
            let mut state = cell.borrow_mut();
            assert!(!state.0, "cache already held on this thread");
            state.0 = true;
            state.1.take()
        });

        // From here on every early exit (error or panic) must release the slot.
        let mut rollback = TakeCleanup {
            cache,
            cached: previous,
            armed: true,
        };

        let value = if let Some(mut reused) = rollback.cached.take() {
            match reset(&mut reused) {
                Ok(()) => reused,
                Err(err) => {
                    // Hand the value to the rollback guard so it returns to the slot.
                    rollback.cached = Some(reused);
                    return Err(err);
                }
            }
        } else {
            create()?
        };

        // Success: the `Cached` guard now owns the responsibility for the slot.
        rollback.disarm();
        Ok(Self {
            value: Some(value),
            cache,
            _not_send: PhantomData,
        })
    }
}
impl<T: 'static> Deref for Cached<T> {
    type Target = T;

    /// Borrows the checked-out value.
    ///
    /// # Panics
    ///
    /// Panics with "value taken after drop" if the guard no longer holds a value.
    fn deref(&self) -> &T {
        let held: Option<&T> = self.value.as_ref();
        held.expect("value taken after drop")
    }
}
impl<T: 'static> DerefMut for Cached<T> {
    /// Mutably borrows the checked-out value.
    ///
    /// # Panics
    ///
    /// Panics with "value taken after drop" if the guard no longer holds a value.
    fn deref_mut(&mut self) -> &mut T {
        let held: Option<&mut T> = self.value.as_mut();
        held.expect("value taken after drop")
    }
}
impl<T: 'static> Drop for Cached<T> {
    /// Moves the value back into the thread-local slot and clears its held flag.
    ///
    /// If the guard no longer holds a value, nothing happens. If the slot's
    /// thread-local storage has already been destroyed (the guard is being
    /// dropped during thread teardown), the value is simply dropped.
    fn drop(&mut self) {
        if let Some(value) = self.value.take() {
            // `try_with` rather than `with`: `with` panics once the thread-local
            // has been destroyed, and a panic while thread-local destructors run
            // aborts the process. On `Err` the closure (and `value` inside it) is
            // dropped, which is the only sensible outcome when the slot is gone.
            let _ = self.cache.try_with(move |cell| {
                let mut slot = cell.borrow_mut();
                debug_assert!(slot.0, "cache expected to be held");
                slot.0 = false;
                slot.1 = Some(value);
            });
        }
    }
}
/// Declares a thread-local cache slot for use with `Cached::take`.
///
/// `thread_local_cache!(static NAME: Ty)` expands to a `thread_local!` static
/// named `NAME` of type `RefCell<(bool, Option<Ty>)>`, const-initialised to
/// `(false, None)`: not held, nothing cached.
///
/// Outer attributes (including doc comments) and a visibility qualifier may
/// precede `static`, and a trailing `;` is accepted, for example
/// `thread_local_cache!(pub(crate) static NAME: Ty;)`.
#[macro_export]
macro_rules! thread_local_cache {
    ($(#[$($attr:tt)*])* $vis:vis static $name:ident : $ty:ty $(;)?) => {
        ::std::thread_local! {
            $(#[$($attr)*])*
            $vis static $name: ::std::cell::RefCell<(bool, ::std::option::Option<$ty>)> =
                const { ::std::cell::RefCell::new((false, ::std::option::Option::None)) };
        }
    };
}
// Unit tests for `Cached` and `thread_local_cache!`. Every test declares its own
// thread-local slot, so no cached state is shared between tests.
#[cfg(test)]
mod tests {
use super::*;
thread_local_cache!(static TEST_CACHE: Vec<u8>);
// A fresh slot holds no value, so `take` must build one through `create`.
#[test]
fn test_take_creates_on_miss() {
let guard = Cached::take(&TEST_CACHE, || Ok::<_, ()>(vec![1, 2, 3]), |_v| Ok(())).unwrap();
assert_eq!(&*guard, &[1, 2, 3]);
}
thread_local_cache!(static REUSE_CACHE: Vec<u8>);
// After the first guard is dropped, the second `take` must reuse the stored
// vector and run `reset` on it; had `create` run instead, the result would be [99].
#[test]
fn test_take_reuses_on_hit() {
let mut guard = Cached::take(
&REUSE_CACHE,
|| Ok::<_, ()>(vec![1, 2, 3]),
|v| {
v.clear();
Ok(())
},
)
.unwrap();
guard.push(4);
drop(guard);
let guard = Cached::take(
&REUSE_CACHE,
|| Ok::<_, ()>(vec![99]),
|v| {
v.clear();
Ok(())
},
)
.unwrap();
assert!(guard.is_empty(), "reset should have cleared the vec");
}
thread_local_cache!(static DROP_CACHE: String);
// Dropping the guard (end of the inner scope) must leave a value stored in the slot.
#[test]
fn test_drop_returns_to_cache() {
{
let _guard = Cached::take(
&DROP_CACHE,
|| Ok::<_, ()>(String::from("hello")),
|_| Ok(()),
)
.unwrap();
}
let has_value = DROP_CACHE.with(|cell| cell.borrow().1.is_some());
assert!(has_value, "drop should return value to cache");
}
thread_local_cache!(static ERR_CACHE: u32);
// A failing `create` surfaces its error and leaves the slot usable: the
// following `take` on the same slot succeeds.
#[test]
fn test_create_error_propagates() {
let result = Cached::take(&ERR_CACHE, || Err::<u32, &str>("create failed"), |_| Ok(()));
assert!(result.is_err());
let guard = Cached::take(&ERR_CACHE, || Ok::<u32, &str>(7), |_| Ok(())).unwrap();
assert_eq!(*guard, 7);
}
thread_local_cache!(static RESET_ERR_CACHE: u32);
// A failing `reset` surfaces its error, and the previously cached value (42)
// is still in the slot afterwards.
#[test]
fn test_reset_error_propagates() {
{
// Seed the slot: the guard is dropped at the end of this scope.
let _guard = Cached::take(&RESET_ERR_CACHE, || Ok::<_, &str>(42), |_| Ok(())).unwrap();
}
let result = Cached::take(
&RESET_ERR_CACHE,
|| Ok::<_, &str>(0),
|_| Err("reset failed"),
);
assert!(result.is_err());
let cached = RESET_ERR_CACHE.with(|cell| cell.borrow().1);
assert_eq!(cached, Some(42));
}
thread_local_cache!(static NESTED_CACHE: Vec<u8>);
// Taking a slot that is already held on this thread panics. Unwinding drops the
// outer guard, which stores its value ([1, 10]) back into the slot.
#[test]
fn test_nested_guards_rejected() {
// Start from a known unheld, empty slot.
NESTED_CACHE.with(|cell| *cell.borrow_mut() = (false, None));
let result = std::panic::catch_unwind(|| {
let mut outer =
Cached::take(&NESTED_CACHE, || Ok::<_, ()>(vec![1]), |_| Ok(())).unwrap();
outer.push(10);
let _inner = Cached::take(&NESTED_CACHE, || Ok::<_, ()>(vec![2]), |_| Ok(())).unwrap();
});
assert!(result.is_err(), "nested take on same thread should panic");
let cached = NESTED_CACHE.with(|cell| cell.borrow().1.clone());
assert_eq!(cached, Some(vec![1, 10]));
}
thread_local_cache!(static PANIC_CREATE_CACHE: u32);
// A panic inside `create` must not leave the slot flagged as held: a later
// `take` on the same slot still succeeds.
#[test]
fn test_create_panic_does_not_poison_held_flag() {
let result = std::panic::catch_unwind(|| {
let _ = Cached::take(
&PANIC_CREATE_CACHE,
|| -> Result<u32, ()> { panic!("create panic") },
|_| Ok(()),
);
});
assert!(result.is_err());
let guard = Cached::take(&PANIC_CREATE_CACHE, || Ok::<_, ()>(7), |_| Ok(())).unwrap();
assert_eq!(*guard, 7);
}
thread_local_cache!(static PANIC_RESET_CACHE: u32);
// A panic inside `reset` must not leave the slot flagged as held. The later
// `take` yields 9 from `create`, i.e. the old value (42) was not put back.
#[test]
fn test_reset_panic_does_not_poison_held_flag() {
{
// Seed the slot: the guard is dropped at the end of this scope.
let _guard = Cached::take(&PANIC_RESET_CACHE, || Ok::<_, ()>(42), |_| Ok(())).unwrap();
}
let result = std::panic::catch_unwind(|| {
let _ = Cached::take(
&PANIC_RESET_CACHE,
|| Ok::<_, ()>(0),
|_| -> Result<(), ()> { panic!("reset panic") },
);
});
assert!(result.is_err());
let guard = Cached::take(&PANIC_RESET_CACHE, || Ok::<_, ()>(9), |_| Ok(())).unwrap();
assert_eq!(*guard, 9);
}
}