// moduvex_runtime/executor/task_local.rs
//! Task-local storage for async contexts.
//!
//! Provides [`TaskLocal<T>`] — a key for per-task storage, analogous to
//! `thread_local!` but scoped to the execution of a future. Values are set
//! via [`TaskLocal::scope`] and read via [`TaskLocal::with`] /
//! [`TaskLocal::try_with`]. A value set by `scope` is visible only while that
//! scope's future is being polled; tasks spawned from inside it do not
//! inherit it.
//!
//! # Example
//! ```
//! moduvex_runtime::task_local! {
//!     static REQUEST_ID: u64;
//! }
//!
//! moduvex_runtime::block_on(async {
//!     REQUEST_ID.scope(42, async {
//!         REQUEST_ID.with(|id| assert_eq!(*id, 42));
//!     }).await;
//! });
//! ```
20
21use std::any::Any;
22use std::cell::RefCell;
23use std::collections::HashMap;
24use std::future::Future;
25use std::marker::PhantomData;
26use std::pin::Pin;
27use std::task::{Context, Poll};
28
// ── Thread-local storage backend ─────────────────────────────────────────────
30
thread_local! {
    /// Per-thread map from a `TaskLocal` key to the value installed by the
    /// `Scope` currently being polled on this thread. Values are type-erased
    /// so a single map can serve every key regardless of its value type.
    static STORAGE: RefCell<HashMap<usize, Box<dyn Any>>> = RefCell::default();
}
34
// ── TaskLocal key ────────────────────────────────────────────────────────────
36
37/// A key for task-local storage, created by the [`task_local!`] macro.
38///
39/// Each static `TaskLocal<T>` has a unique address used as the storage key.
40pub struct TaskLocal<T: 'static> {
41    _marker: PhantomData<T>,
42}
43
44impl<T: 'static> TaskLocal<T> {
45    /// Internal constructor — use [`task_local!`] instead.
46    #[doc(hidden)]
47    pub const fn new() -> Self {
48        Self {
49            _marker: PhantomData,
50        }
51    }
52
53    /// Address-based key for the thread-local HashMap.
54    fn key(&'static self) -> usize {
55        self as *const Self as usize
56    }
57
58    /// Run `future` with `value` set for this key. Restores the previous
59    /// value (if any) after each poll, so other tasks on the same thread
60    /// don't see stale data.
61    pub fn scope<F: Future>(&'static self, value: T, future: F) -> Scope<T, F> {
62        Scope {
63            key: self,
64            value: Some(value),
65            future,
66        }
67    }
68
69    /// Access the current value, panicking if no scope is active.
70    pub fn with<F, R>(&'static self, f: F) -> R
71    where
72        F: FnOnce(&T) -> R,
73    {
74        self.try_with(f)
75            .expect("TaskLocal::with() called outside of a scope")
76    }
77
78    /// Access the current value, returning `Err` if no scope is active.
79    pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError>
80    where
81        F: FnOnce(&T) -> R,
82    {
83        STORAGE.with(|s| {
84            let map = s.borrow();
85            match map.get(&self.key()) {
86                Some(boxed) => {
87                    let val = boxed.downcast_ref::<T>().expect("TaskLocal type mismatch");
88                    Ok(f(val))
89                }
90                None => Err(AccessError),
91            }
92        })
93    }
94}
95
96// SAFETY: TaskLocal itself holds no data — all data lives in thread-local storage.
97unsafe impl<T: 'static> Sync for TaskLocal<T> {}
98
// ── AccessError ──────────────────────────────────────────────────────────────
100
/// Error returned by [`TaskLocal::try_with`] when the key has no value
/// available on the current thread (no enclosing scope is active).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AccessError;

impl std::fmt::Display for AccessError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("task-local value not set in current scope")
    }
}

/// Lets `AccessError` be boxed as `dyn Error` and propagated with `?` by
/// code that works with `std::error::Error`.
impl std::error::Error for AccessError {}
110
// ── Scope future ─────────────────────────────────────────────────────────────
112
113/// A future that sets a task-local value around each poll of an inner future.
114///
115/// Created by [`TaskLocal::scope`].
116pub struct Scope<T: 'static, F: Future> {
117    key: &'static TaskLocal<T>,
118    /// Holds our value when the inner future is *not* being polled.
119    value: Option<T>,
120    future: F,
121}
122
123impl<T: 'static, F: Future> Future for Scope<T, F> {
124    type Output = F::Output;
125
126    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
127        // SAFETY: We only project to `future` (pinned) and `value`/`key` (Unpin).
128        let this = unsafe { self.get_unchecked_mut() };
129
130        // Enter: move our value into thread-local storage, capture previous.
131        let val = this.value.take().expect("Scope polled after completion");
132        let key_addr = this.key.key();
133        let prev = STORAGE.with(|s| s.borrow_mut().insert(key_addr, Box::new(val)));
134
135        // Poll the inner future.
136        let inner = unsafe { Pin::new_unchecked(&mut this.future) };
137        let result = inner.poll(cx);
138
139        // Exit: take value back, restore previous.
140        let current = STORAGE.with(|s| s.borrow_mut().remove(&key_addr));
141        if let Some(p) = prev {
142            STORAGE.with(|s| s.borrow_mut().insert(key_addr, p));
143        }
144
145        match result {
146            Poll::Ready(output) => Poll::Ready(output),
147            Poll::Pending => {
148                // Stash value back for next poll.
149                if let Some(boxed) = current {
150                    this.value = Some(*boxed.downcast::<T>().expect("type mismatch"));
151                }
152                Poll::Pending
153            }
154        }
155    }
156}
157
/// Declare one or more task-local keys.
///
/// Each declaration has the form `[attrs] [vis] static NAME: Type;` and
/// expands to a `static NAME: TaskLocal<Type>`. The semicolon after the last
/// declaration is optional.
///
/// # Example
/// ```
/// moduvex_runtime::task_local! {
///     static MY_KEY: String;
///     static OTHER_KEY: u32
/// }
/// ```
#[macro_export]
macro_rules! task_local {
    // Base case: nothing left to declare.
    () => {};

    // A declaration terminated by `;`, possibly followed by more.
    ($(#[$attr:meta])* $vis:vis static $name:ident : $ty:ty ; $($rest:tt)*) => {
        $(#[$attr])*
        $vis static $name: $crate::executor::task_local::TaskLocal<$ty> =
            $crate::executor::task_local::TaskLocal::new();
        $crate::task_local!($($rest)*);
    };

    // A final declaration without a trailing `;`.
    ($(#[$attr:meta])* $vis:vis static $name:ident : $ty:ty) => {
        $(#[$attr])*
        $vis static $name: $crate::executor::task_local::TaskLocal<$ty> =
            $crate::executor::task_local::TaskLocal::new();
    };
}
176
// ── Tests ────────────────────────────────────────────────────────────────────
178
#[cfg(test)]
mod tests {
    use crate::executor::{block_on, block_on_with_spawn, spawn};

    task_local! {
        static FOO: u32;
        static BAR: String;
    }

    /// `true` when `FOO` has no value in the current context.
    fn foo_unset() -> bool {
        FOO.try_with(|_| ()).is_err()
    }

    #[test]
    fn scope_sets_and_reads_value() {
        block_on(async {
            let body = async {
                let seen = FOO.with(|v| *v);
                assert_eq!(seen, 42);
            };
            FOO.scope(42, body).await;
        });
    }

    #[test]
    fn try_with_returns_err_outside_scope() {
        block_on(async {
            assert!(foo_unset());
        });
    }

    #[test]
    fn nested_scopes_restore_previous() {
        block_on(async {
            let outer = async {
                assert_eq!(FOO.with(|v| *v), 1);
                let inner = async {
                    assert_eq!(FOO.with(|v| *v), 2);
                };
                FOO.scope(2, inner).await;
                // Leaving the inner scope must bring the outer value back.
                assert_eq!(FOO.with(|v| *v), 1);
            };
            FOO.scope(1, outer).await;
        });
    }

    #[test]
    fn multiple_keys_independent() {
        block_on(async {
            let both = async {
                assert_eq!(FOO.with(|v| *v), 99);
                assert_eq!(BAR.with(String::clone), "hello");
            };
            FOO.scope(99, BAR.scope(String::from("hello"), both)).await;
        });
    }

    #[test]
    fn scope_value_not_visible_after_await() {
        block_on(async {
            FOO.scope(10, async {}).await;
            assert!(foo_unset());
        });
    }

    #[test]
    fn spawned_task_does_not_inherit_parent_scope() {
        block_on_with_spawn(async {
            let parent = async {
                let handle = spawn(async { foo_unset() });
                let isolated = handle.await.unwrap();
                assert!(isolated, "spawned task should not see parent scope");
            };
            FOO.scope(777, parent).await;
        });
    }
}