Skip to main content

pros_core/task/
local.rs

1//! A custom TLS implementation that allows for more than 5 entries in TLS.
2//!
//! FreeRTOS task-local storage has a hard limit on the number of entries (5).
//! To work around it, this implementation stores a pointer to a custom TLS struct in the first slot of the FreeRTOS TLS.
//! This trades a bit of speed for the ability to have as many entries as memory allows.
6//!
7//! [`LocalKey`]s can be created with the [`os_task_local!`](crate::os_task_local!) macro.
8//! ## Example
9//! ```rust
10//! os_task_local! {
11//!     static FOO: u32 = 0;
12//!     static BAR: String = String::from("Hello, world!");
13//! }
14//! ```
15
16use alloc::{boxed::Box, collections::BTreeMap};
17use core::{
18    cell::{Cell, RefCell},
19    ptr::NonNull,
20    sync::atomic::AtomicU32,
21};
22
23use spin::Once;
24
25use super::current;
26
/// Global counter that hands out a unique index to each [`LocalKey`] the first
/// time that key is used. The index is the key's entry in the per-task storage map.
///
/// The counter is only ever incremented, so every key gets a distinct index.
static INDEX: AtomicU32 = AtomicU32::new(0);
29
30/// Set a value in OS TLS.
31/// This requires you to leak val so that you can be sure it lives as long as the task.
32/// # Safety
33/// Unsafe because you can change the thread local storage while it is being read.
34unsafe fn thread_local_storage_set<T>(task: pros_sys::task_t, val: &'static T, index: u32) {
35    // Yes, we transmute val. This is the intended use of this function.
36    // SAFETY: caller must ensure borrow rules are followed
37    unsafe {
38        pros_sys::vTaskSetThreadLocalStoragePointer(task, index as _, (val as *const T).cast());
39    }
40}
41
42/// Get a value from OS TLS.
43/// # Safety
44/// Unsafe because we can't check if the type is the same as the one that was set.
45unsafe fn thread_local_storage_get<T>(task: pros_sys::task_t, index: u32) -> Option<&'static T> {
46    // SAFETY: caller must ensure borrow rules are followed and the type is correct
47    unsafe {
48        let val = pros_sys::pvTaskGetThreadLocalStoragePointer(task, index as _);
49        val.cast::<T>().as_ref()
50    }
51}
52
53/// Get or create the [`ThreadLocalStorage`] for the current task.
54fn fetch_storage() -> &'static RefCell<ThreadLocalStorage> {
55    let current = current();
56
57    // Get the thread local storage for this task.
58    // Creating it if it doesn't exist.
59    // SAFETY: This is safe as long as index 0 of the freeRTOS TLS is never set to any other type.
60    unsafe {
61        thread_local_storage_get(current.task, 0).unwrap_or_else(|| {
62            let storage = Box::leak(Box::new(RefCell::new(ThreadLocalStorage {
63                data: BTreeMap::new(),
64            })));
65            thread_local_storage_set(current.task, storage, 0);
66            storage
67        })
68    }
69}
70
71/// A custom thread local storage implementation.
72/// This itself is stored inside real OS TLS, it allows for more than 5 entries in TLS.
73/// [`LocalKey`]s store their data inside this struct.
74struct ThreadLocalStorage {
75    pub data: BTreeMap<usize, NonNull<()>>,
76}
77
78/// A TLS key that owns its data.
79/// Can be created with the [`os_task_local`](crate::os_task_local!) macro.
80#[derive(Debug)]
81pub struct LocalKey<T: 'static> {
82    index: Once<usize>,
83    init: fn() -> T,
84}
85
86impl<T: 'static> LocalKey<T> {
87    /// Creates a new local key that lazily initializes its data.
88    /// init is called to initialize the data when it is first accessed from a new thread.
89    pub const fn new(init: fn() -> T) -> Self {
90        Self {
91            index: Once::new(),
92            init,
93        }
94    }
95
96    /// Get the index of this key, or get the next one if it has never been created before.
97    fn index(&'static self) -> &usize {
98        self.index
99            .call_once(|| INDEX.fetch_add(1, core::sync::atomic::Ordering::Relaxed) as _)
100    }
101
102    /// Passes a reference to the value of this key to the given closure.
103    /// If the value has not been initialized yet, it will be initialized.
104    pub fn with<F, R>(&'static self, f: F) -> R
105    where
106        F: FnOnce(&'static T) -> R,
107    {
108        self.initialize_with((self.init)(), |_, val| f(val))
109    }
110
111    /// Acquires a reference to the value in this TLS key, initializing it with
112    /// `init` if it wasn't already initialized on this task.
113    ///
114    /// If `init` was used to initialize the task local variable, `None` is
115    /// passed as the first argument to `f`. If it was already initialized,
116    /// `Some(init)` is passed to `f`.
117    fn initialize_with<F, R>(&'static self, init: T, f: F) -> R
118    where
119        F: FnOnce(Option<T>, &'static T) -> R,
120    {
121        let storage = fetch_storage();
122        let index = *self.index();
123
124        if let Some(val) = storage.borrow().data.get(&index) {
125            return f(Some(init), unsafe { val.cast().as_ref() });
126        }
127
128        let val = Box::leak(Box::new(init));
129        storage
130            .borrow_mut()
131            .data
132            .insert(index, NonNull::new((val as *mut T).cast::<()>()).unwrap());
133        f(None, val)
134    }
135}
136
137impl<T: 'static> LocalKey<Cell<T>> {
138    /// Sets or initializes the value of this key.
139    ///
140    /// If the value was already initialized, it is overwritten.
141    /// If the value was not initialized, it is initialized with `value`.
142    pub fn set(&'static self, value: T) {
143        self.initialize_with(Cell::new(value), |value, cell| {
144            if let Some(value) = value {
145                // The cell was already initialized, so `value` wasn't used to
146                // initialize it. So we overwrite the current value with the
147                // new one instead.
148                cell.set(value.into_inner());
149            }
150        });
151    }
152
153    /// Gets a copy of the value in this TLS key.
154    pub fn get(&'static self) -> T
155    where
156        T: Copy,
157    {
158        self.with(|cell| cell.get())
159    }
160
161    /// Takes the value out of this TLS key, replacing it with the [`Default`] value.
162    pub fn take(&'static self) -> T
163    where
164        T: Default,
165    {
166        self.with(|cell| cell.replace(Default::default()))
167    }
168
169    /// Replaces the value in this TLS key with the given one, returning the old value.
170    pub fn replace(&'static self, value: T) -> T {
171        self.with(|cell| cell.replace(value))
172    }
173}
174
175impl<T: 'static> LocalKey<RefCell<T>> {
176    /// Acquires a reference to the contained value, initializing it if required.
177    ///
178    /// # Panics
179    ///
180    /// Panics if the value is currently mutably borrowed.
181    pub fn with_borrow<F, R>(&'static self, f: F) -> R
182    where
183        F: FnOnce(&T) -> R,
184    {
185        self.with(|cell| f(&cell.borrow()))
186    }
187
188    /// Acquires a mutable reference to the contained value, initializing it if required.
189    ///
190    /// # Panics
191    ///
192    /// Panics if the value is currently borrowed.
193    pub fn with_borrow_mut<F, R>(&'static self, f: F) -> R
194    where
195        F: FnOnce(&mut T) -> R,
196    {
197        self.with(|cell| f(&mut cell.borrow_mut()))
198    }
199
200    /// Sets or initializes the value of this key, without running the initializer.
201    ///
202    /// # Panics
203    ///
204    /// Panics if the value is currently borrowed.
205    pub fn set(&'static self, value: T) {
206        self.initialize_with(RefCell::new(value), |value, cell| {
207            if let Some(value) = value {
208                // The cell was already initialized, so `value` wasn't used to
209                // initialize it. So we overwrite the current value with the
210                // new one instead.
211                *cell.borrow_mut() = value.into_inner();
212            }
213        });
214    }
215
216    /// Takes the value out of this TLS key, replacing it with the [`Default`] value.
217    ///
218    /// # Panics
219    ///
220    /// Panics if the value is currently borrowed.
221    pub fn take(&'static self) -> T
222    where
223        T: Default,
224    {
225        self.with(|cell| cell.take())
226    }
227
228    /// Replaces the value in this TLS key with the given one, returning the old value.
229    ///
230    /// # Panics
231    ///
232    /// Panics if the value is currently borrowed.
233    pub fn replace(&'static self, value: T) -> T {
234        self.with(|cell| cell.replace(value))
235    }
236}
237
/// Declares one or more task-local [`LocalKey`](crate::task::local::LocalKey) statics.
///
/// Each declaration has the form `static NAME: Type = init_expr;`. The trailing
/// semicolon is required, and a declaration may be preceded by attributes and a
/// visibility. `init_expr` is wrapped in a closure and becomes the key's
/// initializer (see `LocalKey::new`).
///
/// # Example
/// ```rust
/// os_task_local! {
///     static FOO: u32 = 0;
///     static BAR: String = String::new();
/// }
/// ```
#[macro_export]
macro_rules! os_task_local {
    ($($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty = $init:expr;)*) => {
        $(
        $(#[$attr])*
        $vis static $name: $crate::task::local::LocalKey<$t> = $crate::task::local::LocalKey::new(|| $init);
        )*
    };
}