1use std::{
38 cell::RefCell,
39 marker::PhantomData,
40 ops::{Deref, DerefMut},
41 thread::LocalKey,
42};
43
/// Scope guard used while `Cached::take` is creating or resetting a value.
///
/// While armed, dropping it releases the thread-local slot: the held flag is
/// cleared and whatever is in `cached` (possibly `None`) is written back into
/// the slot. This covers every early `return`/`?` and unwind in `take`.
/// `disarm` turns the guard into a no-op once ownership has moved into the
/// returned `Cached` guard.
struct TakeCleanup<T: 'static> {
    /// `true` until `disarm` is called; only an armed guard touches the slot.
    armed: bool,
    /// Value to store back into the slot on release (`None` stores nothing).
    cached: Option<T>,
    /// The `(held, stored)` thread-local slot this guard releases.
    cache: &'static LocalKey<RefCell<(bool, Option<T>)>>,
}

impl<T: 'static> TakeCleanup<T> {
    /// Marks the take as successful so dropping the guard leaves the slot as is.
    const fn disarm(&mut self) {
        self.armed = false;
    }
}

impl<T: 'static> Drop for TakeCleanup<T> {
    fn drop(&mut self) {
        if self.armed {
            let pending = &mut self.cached;
            self.cache.with(|cell| {
                let mut slot = cell.borrow_mut();
                let (held, stored) = &mut *slot;
                debug_assert!(*held, "cache expected to be held");
                *held = false;
                *stored = pending.take();
            });
        }
    }
}
73
74pub struct Cached<T: 'static> {
80 value: Option<T>,
81 cache: &'static LocalKey<RefCell<(bool, Option<T>)>>,
82 _not_send: PhantomData<*const ()>,
83}
84
85impl<T: 'static> Cached<T> {
86 pub fn take<E>(
96 cache: &'static LocalKey<RefCell<(bool, Option<T>)>>,
97 create: impl FnOnce() -> Result<T, E>,
98 reset: impl FnOnce(&mut T) -> Result<(), E>,
99 ) -> Result<Self, E> {
100 let cached = cache.with(|cell| {
101 let mut slot = cell.borrow_mut();
102 assert!(!slot.0, "cache already held on this thread");
103 slot.0 = true;
104 slot.1.take()
105 });
106 let mut cleanup = TakeCleanup {
107 cache,
108 cached,
109 armed: true,
110 };
111 let value = match cleanup.cached.take() {
112 Some(mut v) => {
113 if let Err(err) = reset(&mut v) {
114 cleanup.cached = Some(v);
115 return Err(err);
116 }
117 v
118 }
119 None => create()?,
120 };
121 cleanup.disarm();
122 Ok(Self {
123 value: Some(value),
124 cache,
125 _not_send: PhantomData,
126 })
127 }
128}
129
130impl<T: 'static> Deref for Cached<T> {
131 type Target = T;
132
133 fn deref(&self) -> &T {
134 self.value.as_ref().expect("value taken after drop")
135 }
136}
137
138impl<T: 'static> DerefMut for Cached<T> {
139 fn deref_mut(&mut self) -> &mut T {
140 self.value.as_mut().expect("value taken after drop")
141 }
142}
143
144impl<T: 'static> Drop for Cached<T> {
145 fn drop(&mut self) {
146 if let Some(v) = self.value.take() {
147 self.cache.with(|cell| {
148 let mut slot = cell.borrow_mut();
149 debug_assert!(slot.0, "cache expected to be held");
150 slot.0 = false;
151 slot.1 = Some(v);
152 });
153 }
154 }
155}
156
/// Declares a thread-local slot usable with `Cached::take`.
///
/// Expands to a `std::thread_local!` static of type
/// `RefCell<(bool, Option<T>)>`, initialised to `(false, None)`. The `bool`
/// is the "held" flag and the `Option` stores the cached value between uses.
///
/// As with `thread_local!`, the `static` may be preceded by attributes
/// (including doc comments) and a visibility, and may be followed by `;`:
///
/// ```ignore
/// thread_local_cache!(static SCRATCH: Vec<u8>);
/// thread_local_cache!(
///     /// Scratch buffer shared across the crate.
///     pub(crate) static SHARED: String;
/// );
/// ```
#[macro_export]
macro_rules! thread_local_cache {
    ($(#[$attr:meta])* $vis:vis static $name:ident : $ty:ty $(;)?) => {
        ::std::thread_local! {
            $(#[$attr])*
            $vis static $name: ::std::cell::RefCell<(bool, ::core::option::Option<$ty>)> =
                const { ::std::cell::RefCell::new((false, ::core::option::Option::None)) };
        }
    };
}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Copies out a slot's `(held, stored)` pair.
    fn slot_state<T: Clone>(
        key: &'static std::thread::LocalKey<std::cell::RefCell<(bool, Option<T>)>>,
    ) -> (bool, Option<T>) {
        key.with(|cell| cell.borrow().clone())
    }

    /// Runs `f` and reports whether it panicked.
    fn panics(f: impl FnOnce() + std::panic::UnwindSafe) -> bool {
        std::panic::catch_unwind(f).is_err()
    }

    thread_local_cache!(static TEST_CACHE: Vec<u8>);

    #[test]
    fn test_take_creates_on_miss() {
        let fresh =
            Cached::take(&TEST_CACHE, || Ok::<_, ()>(vec![1, 2, 3]), |_| Ok(())).unwrap();
        assert_eq!(*fresh, [1, 2, 3]);
    }

    thread_local_cache!(static REUSE_CACHE: Vec<u8>);

    #[test]
    fn test_take_reuses_on_hit() {
        let clear = |buf: &mut Vec<u8>| -> Result<(), ()> {
            buf.clear();
            Ok(())
        };

        let mut first = Cached::take(&REUSE_CACHE, || Ok(vec![1, 2, 3]), clear).unwrap();
        first.push(4);
        drop(first);

        let second = Cached::take(&REUSE_CACHE, || Ok(vec![99]), clear).unwrap();
        assert!(second.is_empty(), "reset should have cleared the vec");
    }

    thread_local_cache!(static DROP_CACHE: String);

    #[test]
    fn test_drop_returns_to_cache() {
        let guard = Cached::take(
            &DROP_CACHE,
            || Ok::<_, ()>(String::from("hello")),
            |_| Ok(()),
        )
        .unwrap();
        drop(guard);

        let (_, stored) = slot_state(&DROP_CACHE);
        assert!(stored.is_some(), "drop should return value to cache");
    }

    thread_local_cache!(static ERR_CACHE: u32);

    #[test]
    fn test_create_error_propagates() {
        let failed = Cached::take(&ERR_CACHE, || Err::<u32, &str>("create failed"), |_| Ok(()));
        assert!(failed.is_err());

        let recovered = Cached::take(&ERR_CACHE, || Ok::<u32, &str>(7), |_| Ok(())).unwrap();
        assert_eq!(*recovered, 7);
    }

    thread_local_cache!(static RESET_ERR_CACHE: u32);

    #[test]
    fn test_reset_error_propagates() {
        // Seed the slot with 42.
        drop(Cached::take(&RESET_ERR_CACHE, || Ok::<_, &str>(42), |_| Ok(())).unwrap());

        let outcome = Cached::take(
            &RESET_ERR_CACHE,
            || Ok::<_, &str>(0),
            |_| Err("reset failed"),
        );
        assert!(outcome.is_err());

        // The value that failed to reset must still be in the slot.
        let (_, stored) = slot_state(&RESET_ERR_CACHE);
        assert_eq!(stored, Some(42));
    }

    thread_local_cache!(static NESTED_CACHE: Vec<u8>);

    #[test]
    fn test_nested_guards_rejected() {
        NESTED_CACHE.with(|cell| *cell.borrow_mut() = (false, None));

        let rejected = panics(|| {
            let mut outer =
                Cached::take(&NESTED_CACHE, || Ok::<_, ()>(vec![1]), |_| Ok(())).unwrap();
            outer.push(10);
            let _inner =
                Cached::take(&NESTED_CACHE, || Ok::<_, ()>(vec![2]), |_| Ok(())).unwrap();
        });
        assert!(rejected, "nested take on same thread should panic");

        // Unwinding drops `outer`, which must return its mutated value.
        let (_, stored) = slot_state(&NESTED_CACHE);
        assert_eq!(stored, Some(vec![1, 10]));
    }

    thread_local_cache!(static PANIC_CREATE_CACHE: u32);

    #[test]
    fn test_create_panic_does_not_poison_held_flag() {
        let blew_up = panics(|| {
            let _ = Cached::take(
                &PANIC_CREATE_CACHE,
                || -> Result<u32, ()> { panic!("create panic") },
                |_| Ok(()),
            );
        });
        assert!(blew_up);

        let guard = Cached::take(&PANIC_CREATE_CACHE, || Ok::<_, ()>(7), |_| Ok(())).unwrap();
        assert_eq!(*guard, 7);
    }

    thread_local_cache!(static PANIC_RESET_CACHE: u32);

    #[test]
    fn test_reset_panic_does_not_poison_held_flag() {
        // Seed the slot with 42 so the next take goes through `reset`.
        drop(Cached::take(&PANIC_RESET_CACHE, || Ok::<_, ()>(42), |_| Ok(())).unwrap());

        let blew_up = panics(|| {
            let _ = Cached::take(
                &PANIC_RESET_CACHE,
                || Ok::<_, ()>(0),
                |_| -> Result<(), ()> { panic!("reset panic") },
            );
        });
        assert!(blew_up);

        // The value being reset was dropped, so this take creates a fresh one.
        let guard = Cached::take(&PANIC_RESET_CACHE, || Ok::<_, ()>(9), |_| Ok(())).unwrap();
        assert_eq!(*guard, 9);
    }
}