vexide_async/local.rs
1//! Task-local storage
2//!
3//! Task-local storage is a way to create global variables specific to the current task that live
4//! for the entirety of the task's lifetime, almost like statics. Since they are local to the task,
5//! they implement [`Send`] and [`Sync`], regardless of what the underlying data does or does not
6//! implement.
7//!
8//! Task-locals can be declared using the [`task_local`] macro, which creates a [`LocalKey`] with
9//! the same name that can be used to access the local.
10
11use std::{
12 any::Any,
13 boxed::Box,
14 cell::{BorrowError, BorrowMutError, Cell, RefCell, UnsafeCell},
15 collections::btree_map::BTreeMap,
16 ptr,
17 rc::Rc,
18 sync::{
19 atomic::{AtomicU32, Ordering},
20 LazyLock,
21 },
22};
23
24use crate::executor::EXECUTOR;
25
/// A variable stored in task-local storage.
///
/// A `LocalKey` does not contain the value itself. It only carries the initializer function and
/// a lazily assigned numeric key; the value is looked up (and created on first access) in the
/// storage of whichever task is current when the key is used.
///
/// # Usage
///
/// The primary mode of accessing this is through the [`LocalKey::with`] method. For
/// [`LocalKey<RefCell<T>>`] and [`LocalKey<Cell<T>>`], additional convenience methods are added
/// that mirror the underlying [`RefCell<T>`] or [`Cell<T>`]'s methods.
///
/// # Examples
///
/// ```
/// task_local! {
///     static PHI: f64 = 1.61803;
///     static COUNTER: Cell<u32> = Cell::new(0);
///     static NAMES: RefCell<Vec<String>> = RefCell::new(Vec::new());
/// }
///
/// // LocalKey::with accepts a function and applies it to a reference, returning whatever value
/// // the function returned
/// let double_phi = PHI.with(|&phi| phi * 2.0);
/// assert_eq!(double_phi, 1.61803 * 2.0);
///
/// // We can use interior mutability
/// COUNTER.set(1);
/// assert_eq!(COUNTER.get(), 1);
///
/// NAMES.with_borrow_mut(|names| names.push(String::from("Johnny")));
/// NAMES.with_borrow(|names| assert_eq!(names.len(), 1));
///
/// use vexide::async_runtime::spawn;
///
/// // Creating another task
/// spawn(async {
///     // The locals of the previous task are completely different.
///     assert_eq!(COUNTER.get(), 0);
///     NAMES.with_borrow(|names| assert_eq!(names.len(), 0));
/// }).await;
/// ```
#[derive(Debug)]
pub struct LocalKey<T: 'static> {
    /// Produces the initial value of the local when it is first accessed in a storage.
    init: fn() -> T,
    /// Numeric identifier of this local, assigned the first time it is needed.
    key: LazyLock<u32>,
}

// SAFETY: a `LocalKey<T>` is made of a function pointer and a `LazyLock<u32>` only; it never
// holds a value of type `T`, so sharing or moving the key itself does not share or move any `T`.
unsafe impl<T> Sync for LocalKey<T> {}
// SAFETY: see the `Sync` impl above; the same reasoning applies.
unsafe impl<T> Send for LocalKey<T> {}
72
/// Declares task-local variables in [`LocalKey`]s of the same names.
///
/// One or more `static` declarations may be listed. The attributes (including doc comments) and
/// the visibility written on each declaration are applied to the generated [`LocalKey`] static.
///
/// # Examples
///
/// ```
/// task_local! {
///     static PHI: f64 = 1.61803;
///     static COUNTER: Cell<u32> = Cell::new(0);
///     static NAMES: RefCell<Vec<String>> = RefCell::new(Vec::new());
/// }
/// ```
#[expect(
    edition_2024_expr_fragment_specifier,
    reason = "allows matching `const` expressions"
)]
#[macro_export]
macro_rules! task_local {
    // Base case: a single declaration expands to one `LocalKey` static.
    {
        $(#[$attr:meta])*
        $vis:vis static $name:ident: $type:ty = $init:expr;
    } => {
        $(#[$attr])*
        // publicly reexported in crate::task
        $vis static $name: $crate::task::LocalKey<$type> = {
            fn init() -> $type { $init }
            $crate::task::LocalKey::new(init)
        };
    };

    // Recursive case: expand the first declaration, then the remaining ones.
    {
        $(#[$attr:meta])*
        $vis:vis static $name:ident: $type:ty = $init:expr;
        $($rest:tt)*
    } => {
        // The attributes must be forwarded here, otherwise they would be dropped for every
        // declaration except the last one in the list.
        $crate::task_local!($(#[$attr])* $vis static $name: $type = $init;);
        $crate::task_local!($($rest)*);
    }
}
111pub use task_local;
112
113impl<T: 'static> LocalKey<T> {
114 #[doc(hidden)]
115 pub const fn new(init: fn() -> T) -> Self {
116 static LOCAL_KEY_COUNTER: AtomicU32 = AtomicU32::new(0);
117
118 Self {
119 init,
120 key: LazyLock::new(|| LOCAL_KEY_COUNTER.fetch_add(1, Ordering::Relaxed)),
121 }
122 }
123
124 /// Obtains a reference to the local and applies it to the function `f`, returning whatever `f`
125 /// returned.
126 ///
127 /// # Examples
128 ///
129 /// ```
130 /// task_local! {
131 /// static PHI: f64 = 1.61803;
132 /// }
133 ///
134 /// let double_phi = PHI.with(|&phi| phi * 2.0);
135 /// assert_eq!(double_phi, 1.61803 * 2.0);
136 /// ```
137 pub fn with<F, R>(&'static self, f: F) -> R
138 where
139 F: FnOnce(&T) -> R,
140 {
141 TaskLocalStorage::with_current(|storage| {
142 // SAFETY: get_or_init is always called with the same return type, T
143 // Also, `key` is unique for this local key.
144 f(unsafe { storage.get_or_init(*self.key, self.init) })
145 })
146 }
147}
148
149impl<T: 'static> LocalKey<Cell<T>> {
150 /// Returns a copy of the contained value.
151 pub fn get(&'static self) -> T
152 where
153 T: Copy,
154 {
155 self.with(Cell::get)
156 }
157
158 /// Sets the contained value.
159 pub fn set(&'static self, value: T) {
160 self.with(|cell| cell.set(value));
161 }
162
163 /// Takes the value of contained value, leaving [`Default::default()`] in its place.
164 pub fn take(&'static self) -> T
165 where
166 T: Default,
167 {
168 self.with(Cell::take)
169 }
170
171 /// Replaces the contained value with `value`, returning the old contained value.
172 pub fn replace(&'static self, value: T) -> T {
173 self.with(|cell| cell.replace(value))
174 }
175}
176
177impl<T: 'static> LocalKey<RefCell<T>> {
178 /// Immutably borrows from the [`RefCell`] and applies the obtained reference to `f`.
179 ///
180 /// # Panics
181 ///
182 /// Panics if the value is currently mutably borrowed. For a non-panicking variant, use
183 /// [`LocalKey::try_with_borrow`].
184 pub fn with_borrow<F, R>(&'static self, f: F) -> R
185 where
186 F: FnOnce(&T) -> R,
187 {
188 self.with(|cell| f(&cell.borrow()))
189 }
190
191 /// Mutably borrows from the [`RefCell`] and applies the obtained reference to `f`.
192 ///
193 /// # Panics
194 ///
195 /// Panics if the value is currently borrowed. For a non-panicking variant, use
196 /// [`LocalKey::try_with_borrow_mut`].
197 pub fn with_borrow_mut<F, R>(&'static self, f: F) -> R
198 where
199 F: FnOnce(&mut T) -> R,
200 {
201 self.with(|cell| f(&mut cell.borrow_mut()))
202 }
203
204 /// Tries to immutably borrow the contained value, returning an error if it is currently
205 /// mutably borrowed, and applies the obtained reference to `f`.
206 ///
207 /// This is the non-panicking variant of [`LocalKey::with_borrow`].
208 ///
209 /// # Errors
210 ///
211 /// Returns [`BorrowError`] if the contained value is currently mutably borrowed.
212 pub fn try_with_borrow<F, R>(&'static self, f: F) -> Result<R, BorrowError>
213 where
214 F: FnOnce(&T) -> R,
215 {
216 self.with(|cell| cell.try_borrow().map(|value| f(&value)))
217 }
218
219 /// Tries to mutably borrow the contained value, returning an error if it is currently borrowed,
220 /// and applies the obtained reference to `f`.
221 ///
222 /// This is the non-panicking variant of [`LocalKey::with_borrow_mut`].
223 ///
224 /// # Errors
225 ///
226 /// Returns [`BorrowMutError`] if the contained value is currently borrowed.
227 pub fn try_with_borrow_mut<F, R>(&'static self, f: F) -> Result<R, BorrowMutError>
228 where
229 F: FnOnce(&T) -> R,
230 {
231 self.with(|cell| cell.try_borrow_mut().map(|value| f(&value)))
232 }
233
234 /// Sets the contained value.
235 ///
236 /// # Panics
237 ///
238 /// Panics if the value is currently borrowed.
239 pub fn set(&'static self, value: T) {
240 self.with_borrow_mut(|refmut| *refmut = value);
241 }
242
243 /// Takes the contained value, leaving [`Default::default()`] in its place.
244 ///
245 /// # Panics
246 ///
247 /// Panics if the value is currently borrowed.
248 pub fn take(&'static self) -> T
249 where
250 T: Default,
251 {
252 self.with(RefCell::take)
253 }
254
255 /// Replaces the contained value with `value`, returning the old contained value.
256 ///
257 /// # Panics
258 ///
259 /// Panics if the value is currently borrowed.
260 pub fn replace(&'static self, value: T) -> T {
261 self.with(|cell| cell.replace(value))
262 }
263}
264
/// A heap-allocated task-local value whose concrete type has been erased.
struct ErasedTaskLocal {
    value: Box<dyn Any>,
}

impl ErasedTaskLocal {
    /// Moves `value` onto the heap and forgets its concrete type.
    #[doc(hidden)]
    fn new<T: 'static>(value: T) -> Self {
        let value: Box<dyn Any> = Box::new(value);
        Self { value }
    }

    /// Returns the stored value viewed as a `T`.
    ///
    /// With debug assertions enabled the stored type is checked and a mismatch panics. Without
    /// them the pointer is reinterpreted with no check at all.
    ///
    /// # Safety
    ///
    /// Caller guarantees T is the right type
    unsafe fn get<T: 'static>(&self) -> &T {
        if cfg!(debug_assertions) {
            return self.value.downcast_ref::<T>().unwrap();
        }

        let erased = ptr::from_ref(&*self.value);
        // SAFETY: the caller guarantees that the boxed value is a `T`, so the data pointer of
        // the trait object points to a valid `T` that lives as long as `self`.
        unsafe { &*erased.cast::<T>() }
    }
}
288
289// Fallback TLS block for when reading from outside of a task.
290thread_local! {
291 static FALLBACK_TLS: TaskLocalStorage = const { TaskLocalStorage::new() };
292}
293
294#[derive(Debug)]
295pub(crate) struct TaskLocalStorage {
296 locals: UnsafeCell<BTreeMap<u32, ErasedTaskLocal>>,
297}
298
299impl TaskLocalStorage {
300 pub(crate) const fn new() -> Self {
301 Self {
302 locals: UnsafeCell::new(BTreeMap::new()),
303 }
304 }
305
306 pub(crate) fn scope(value: Rc<TaskLocalStorage>, scope: impl FnOnce()) {
307 let outer_scope = EXECUTOR.with(|ex| (*ex.tls.borrow_mut()).replace(value));
308
309 scope();
310
311 EXECUTOR.with(|ex| {
312 *ex.tls.borrow_mut() = outer_scope;
313 });
314 }
315
316 /// Gets the Task Local Storage data for the current task.
317 pub(crate) fn with_current<F, R>(f: F) -> R
318 where
319 F: FnOnce(&Self) -> R,
320 {
321 EXECUTOR.with(|ex| {
322 if let Some(tls) = ex.tls.borrow().as_ref() {
323 f(tls)
324 } else {
325 FALLBACK_TLS.with(|fallback| f(fallback))
326 }
327 })
328 }
329
330 /// Gets a reference to the Task Local Storage item identified by the given key.
331 ///
332 /// It is invalid to call this function multiple times with the same key and a different `T`.
333 pub(crate) unsafe fn get_or_init<T: 'static>(&self, key: u32, init: fn() -> T) -> &T {
334 // We need to be careful to not make mutable references to values already inserted into the
335 // map because the current task might have existing shared references to that data.
336 // It's okay if the pointer (ErasedTaskLocal) gets moved around, we just can't
337 // assert invalid exclusive access over its contents.
338
339 let locals = self.locals.get();
340 unsafe {
341 // init() could initialize another task local recursively, so we need to be sure there's no mutable
342 // reference to `self.locals` when we call it. We can't use the entry API because of this.
343
344 #[expect(
345 clippy::map_entry,
346 reason = "cannot hold mutable reference over init() call"
347 )]
348 if !(*locals).contains_key(&key) {
349 let new_value = ErasedTaskLocal::new(init());
350 (*locals).insert(key, new_value);
351 }
352
353 (*locals).get(&key).unwrap().get()
354 }
355 }
356}