moduvex_runtime/executor/
task_local.rs1use std::any::Any;
22use std::cell::RefCell;
23use std::collections::HashMap;
24use std::future::Future;
25use std::marker::PhantomData;
26use std::pin::Pin;
27use std::task::{Context, Poll};
28
thread_local! {
    /// Per-thread storage for all task-local keys.
    ///
    /// Maps a `TaskLocal`'s identity (its address, see `TaskLocal::key`) to the
    /// type-erased value that the `Scope` currently being polled on this thread
    /// installed for that key.
    static STORAGE: RefCell<HashMap<usize, Box<dyn Any>>> = RefCell::default();
}
34
35pub struct TaskLocal<T: 'static> {
41 _marker: PhantomData<T>,
42}
43
44impl<T: 'static> TaskLocal<T> {
45 #[doc(hidden)]
47 pub const fn new() -> Self {
48 Self {
49 _marker: PhantomData,
50 }
51 }
52
53 fn key(&'static self) -> usize {
55 self as *const Self as usize
56 }
57
58 pub fn scope<F: Future>(&'static self, value: T, future: F) -> Scope<T, F> {
62 Scope {
63 key: self,
64 value: Some(value),
65 future,
66 }
67 }
68
69 pub fn with<F, R>(&'static self, f: F) -> R
71 where
72 F: FnOnce(&T) -> R,
73 {
74 self.try_with(f)
75 .expect("TaskLocal::with() called outside of a scope")
76 }
77
78 pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError>
80 where
81 F: FnOnce(&T) -> R,
82 {
83 STORAGE.with(|s| {
84 let map = s.borrow();
85 match map.get(&self.key()) {
86 Some(boxed) => {
87 let val = boxed.downcast_ref::<T>().expect("TaskLocal type mismatch");
88 Ok(f(val))
89 }
90 None => Err(AccessError),
91 }
92 })
93 }
94}
95
96unsafe impl<T: 'static> Sync for TaskLocal<T> {}
98
/// Error returned by `TaskLocal::try_with` when no value is installed for the
/// key in the scope currently being polled on this thread.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AccessError;

impl std::fmt::Display for AccessError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("task-local value not set in current scope")
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` and other
// error types that require `std::error::Error`.
impl std::error::Error for AccessError {}
110
111pub struct Scope<T: 'static, F: Future> {
117 key: &'static TaskLocal<T>,
118 value: Option<T>,
120 future: F,
121}
122
123impl<T: 'static, F: Future> Future for Scope<T, F> {
124 type Output = F::Output;
125
126 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
127 let this = unsafe { self.get_unchecked_mut() };
129
130 let val = this.value.take().expect("Scope polled after completion");
132 let key_addr = this.key.key();
133 let prev = STORAGE.with(|s| s.borrow_mut().insert(key_addr, Box::new(val)));
134
135 let inner = unsafe { Pin::new_unchecked(&mut this.future) };
137 let result = inner.poll(cx);
138
139 let current = STORAGE.with(|s| s.borrow_mut().remove(&key_addr));
141 if let Some(p) = prev {
142 STORAGE.with(|s| s.borrow_mut().insert(key_addr, p));
143 }
144
145 match result {
146 Poll::Ready(output) => Poll::Ready(output),
147 Poll::Pending => {
148 if let Some(boxed) = current {
150 this.value = Some(*boxed.downcast::<T>().expect("type mismatch"));
151 }
152 Poll::Pending
153 }
154 }
155 }
156}
157
/// Declares one or more task-local keys as `static` items of type
/// `executor::task_local::TaskLocal<T>`.
///
/// Each declaration has the form `[attributes] [visibility] static NAME: Type;`.
/// The `;` after the last declaration is optional.
///
/// ```ignore
/// task_local! {
///     static REQUEST_ID: u64;
///     pub static USER: String
/// }
/// ```
#[macro_export]
macro_rules! task_local {
    // A declaration terminated by `;`, possibly followed by more declarations.
    ($(#[$attr:meta])* $vis:vis static $name:ident : $ty:ty ; $($rest:tt)*) => {
        $(#[$attr])*
        $vis static $name: $crate::executor::task_local::TaskLocal<$ty> =
            $crate::executor::task_local::TaskLocal::new();
        $crate::task_local!($($rest)*);
    };
    // A final declaration without a trailing `;`.
    ($(#[$attr:meta])* $vis:vis static $name:ident : $ty:ty) => {
        $(#[$attr])*
        $vis static $name: $crate::executor::task_local::TaskLocal<$ty> =
            $crate::executor::task_local::TaskLocal::new();
    };
    // Base case: nothing left to declare.
    () => {};
}
176
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

    use crate::executor::{block_on, block_on_with_spawn, spawn};

    task_local! {
        static FOO: u32;
        static BAR: String;
    }

    /// Future that wakes itself and returns `Pending` exactly once, then
    /// `Ready(())`. Used to drive an enclosing `Scope` through its `Pending` path.
    struct YieldOnce {
        yielded: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            if self.yielded {
                Poll::Ready(())
            } else {
                self.yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }

    fn yield_once() -> YieldOnce {
        YieldOnce { yielded: false }
    }

    /// Waker whose operations all do nothing, for polling futures by hand.
    fn noop_waker() -> Waker {
        fn noop_clone(_: *const ()) -> RawWaker {
            RawWaker::new(std::ptr::null(), &NOOP_VTABLE)
        }
        fn noop(_: *const ()) {}
        static NOOP_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);
        // SAFETY: none of the vtable functions dereference the (null) data
        // pointer and all of them are no-ops, so the RawWaker contract holds.
        unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &NOOP_VTABLE)) }
    }

    #[test]
    fn scope_sets_and_reads_value() {
        block_on(async {
            FOO.scope(42, async {
                FOO.with(|v| assert_eq!(*v, 42));
            })
            .await;
        });
    }

    #[test]
    fn try_with_returns_err_outside_scope() {
        block_on(async {
            assert!(FOO.try_with(|_| ()).is_err());
        });
    }

    #[test]
    fn nested_scopes_restore_previous() {
        block_on(async {
            FOO.scope(1, async {
                FOO.with(|v| assert_eq!(*v, 1));
                FOO.scope(2, async {
                    FOO.with(|v| assert_eq!(*v, 2));
                })
                .await;
                FOO.with(|v| assert_eq!(*v, 1));
            })
            .await;
        });
    }

    #[test]
    fn multiple_keys_independent() {
        block_on(async {
            FOO.scope(99, async {
                BAR.scope(String::from("hello"), async {
                    FOO.with(|v| assert_eq!(*v, 99));
                    BAR.with(|v| assert_eq!(v, "hello"));
                })
                .await;
            })
            .await;
        });
    }

    #[test]
    fn scope_value_not_visible_after_await() {
        block_on(async {
            FOO.scope(10, async {}).await;
            assert!(FOO.try_with(|_| ()).is_err());
        });
    }

    #[test]
    fn value_survives_pending_poll() {
        block_on(async {
            FOO.scope(5, async {
                FOO.with(|v| assert_eq!(*v, 5));
                yield_once().await;
                FOO.with(|v| assert_eq!(*v, 5));
            })
            .await;
            assert!(FOO.try_with(|_| ()).is_err());
        });
    }

    #[test]
    fn nested_scope_restores_outer_after_pending() {
        block_on(async {
            FOO.scope(1, async {
                FOO.scope(2, async {
                    yield_once().await;
                    FOO.with(|v| assert_eq!(*v, 2));
                })
                .await;
                FOO.with(|v| assert_eq!(*v, 1));
            })
            .await;
        });
    }

    #[test]
    fn suspended_scope_does_not_leak_value() {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let mut fut = Box::pin(FOO.scope(3, async {
            yield_once().await;
            FOO.with(|v| *v)
        }));

        assert!(fut.as_mut().poll(&mut cx).is_pending());
        // Between polls the value must not be visible on this thread.
        assert!(FOO.try_with(|_| ()).is_err());
        assert_eq!(fut.as_mut().poll(&mut cx), Poll::Ready(3));
        assert!(FOO.try_with(|_| ()).is_err());
    }

    #[test]
    fn spawned_task_does_not_inherit_parent_scope() {
        block_on_with_spawn(async {
            FOO.scope(777, async {
                let jh = spawn(async { FOO.try_with(|_| ()).is_err() });
                assert!(
                    jh.await.unwrap(),
                    "spawned task should not see parent scope"
                );
            })
            .await;
        });
    }
}