pros_simulator/host/thread_local.rs

1use std::mem::size_of;
2
3use async_trait::async_trait;
4use wasmtime::{AsContextMut, SharedMemory};
5
6use super::{memory::SharedMemoryExt, HostCtx, WasmAllocator};
7
/// Number of thread-local storage slots reserved for each task.
///
/// Each slot is one `u32` in guest memory; `TaskStorage` allocates an array of
/// this many `u32`s per task and rejects indices outside `0..NUM_THREAD_LOCAL_STORAGE_POINTERS`.
pub const NUM_THREAD_LOCAL_STORAGE_POINTERS: usize = 5;
9
10// #[derive(Debug, Default)]
11// pub struct ThreadLocalStorage {
12//     pub tasks: HashMap<Task, TaskStorage>,
13// }
14
15// impl ThreadLocalStorage {
16//     pub async fn get(
17//         &mut self,
18//         store: impl AsContextMut<Data = impl Send>,
19//         allocator: &WasmAllocator,
20//         task: Task,
21//     ) -> TaskStorage {
22//         if let Some(storage) = self.tasks.get(&task) {
23//             return *storage;
24//         }
25
26//         let storage = TaskStorage::new(store, allocator).await;
27//         self.tasks.insert(task, storage);
28//         storage
29//     }
30// }
31
32#[derive(Debug, Clone, Copy)]
33pub struct TaskStorage {
34    base_ptr: u32,
35}
36
37impl TaskStorage {
38    pub async fn new(
39        store: impl AsContextMut<Data = impl Send>,
40        allocator: &WasmAllocator,
41    ) -> Self {
42        let base_ptr = allocator
43            .memalign(
44                store,
45                std::alloc::Layout::new::<[u32; NUM_THREAD_LOCAL_STORAGE_POINTERS]>(),
46            )
47            .await;
48        Self { base_ptr }
49    }
50
51    fn assert_in_bounds(index: i32) {
52        if index < 0 || index as usize >= NUM_THREAD_LOCAL_STORAGE_POINTERS {
53            panic!(
54                "Thread local storage index out of bounds:\n\
55                index {index} should be more than 0 and less than {NUM_THREAD_LOCAL_STORAGE_POINTERS}."
56            );
57        }
58    }
59    pub fn get_address(&self, index: i32) -> u32 {
60        Self::assert_in_bounds(index);
61
62        self.base_ptr + (index as u32 * size_of::<u32>() as u32)
63    }
64    pub fn get(&self, memory: SharedMemory, index: i32) -> u32 {
65        Self::assert_in_bounds(index);
66        let address = self.get_address(index);
67        let buffer = memory
68            .read_relaxed(address as usize, size_of::<u32>())
69            .unwrap();
70        u32::from_le_bytes(buffer.try_into().unwrap())
71    }
72    pub fn set(&mut self, memory: SharedMemory, index: i32, value: u32) {
73        Self::assert_in_bounds(index);
74        let address = self.get_address(index);
75        let buffer = value.to_le_bytes();
76        memory.write_relaxed(address as usize, &buffer).unwrap();
77    }
78}
79
80#[async_trait]
81pub trait GetTaskStorage {
82    async fn task_storage(&mut self, task_handle: u32) -> TaskStorage;
83}
84
85#[async_trait]
86impl<T, D> GetTaskStorage for T
87where
88    T: HostCtx + wasmtime::AsContextMut<Data = D> + Send,
89    D: Send,
90{
91    async fn task_storage(&mut self, task_handle: u32) -> TaskStorage {
92        let task = self
93            .tasks_lock()
94            .await
95            .by_id(task_handle)
96            .expect("invalid task handle");
97
98        let mut task = task.lock().await;
99        task.local_storage(self).await
100    }
101}