cubecl_common/
map.rs

1use crate::stub::{Arc, Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard};
2use hashbrown::HashMap;
3
/// Storage behind a [`SharedStateMap`]: every value sits in its own
/// `Arc<RwLock<_>>`, so it can be handed out and locked independently of the map.
type State<K, V> = HashMap<K, Arc<RwLock<V>>>;

/// A thread-safe map whose values are each guarded by their own read-write lock.
///
/// One mutex protects the map structure (lookups, insertions, clearing), while
/// reads and writes of an individual value go through that value's
/// [`SharedState`] handle.
pub struct SharedStateMap<K, V> {
    /// Backing map; starts out as `None` and is created when first needed.
    state: Mutex<Option<State<K, V>>>,
}
10
/// A handle to a single value stored in a [SharedStateMap].
///
/// The handle shares ownership of the value with the map (it holds a clone of
/// the same `Arc`), so it stays usable even after the map entry is replaced or
/// the map is cleared.
pub struct SharedState<V> {
    val: Arc<RwLock<V>>,
}

impl<V> SharedState<V> {
    /// Locks the value for shared, read-only access and returns the guard.
    ///
    /// # Panics
    ///
    /// Panics if acquiring the read lock returns an error (for `std`-backed
    /// locks, this happens when the lock is poisoned).
    pub fn read(&self) -> RwLockReadGuard<'_, V> {
        let lock: &RwLock<V> = &self.val;
        lock.read().unwrap()
    }

    /// Locks the value for exclusive, mutable access and returns the guard.
    ///
    /// # Panics
    ///
    /// Panics if acquiring the write lock returns an error (for `std`-backed
    /// locks, this happens when the lock is poisoned).
    pub fn write(&self) -> RwLockWriteGuard<'_, V> {
        let lock: &RwLock<V> = &self.val;
        lock.write().unwrap()
    }
}
27
28impl<K, V> Default for SharedStateMap<K, V>
29where
30    K: core::hash::Hash + core::cmp::PartialEq + core::cmp::Eq,
31{
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl<K, V> SharedStateMap<K, V>
38where
39    K: core::hash::Hash + core::cmp::PartialEq + core::cmp::Eq,
40{
41    /// Creates a new, empty `SharedStateMap`.
42    pub const fn new() -> Self {
43        Self {
44            state: Mutex::new(None),
45        }
46    }
47
48    /// Retrieves a value associated with the given key, if it exists.
49    pub fn get(&self, k: &K) -> Option<SharedState<V>> {
50        let mut state = self.state.lock().unwrap();
51        let map = get_or_init::<K, V>(&mut state);
52
53        match map.get(k) {
54            Some(val) => Some(SharedState { val: val.clone() }),
55            None => None,
56        }
57    }
58
59    /// Retrieves a value associated with the given key, or inserts a new value using the provided
60    /// initializer function if the key does not exist.
61    pub fn get_or_init<Fn: FnMut(&K) -> V>(&self, k: &K, mut init: Fn) -> SharedState<V>
62    where
63        K: Clone,
64    {
65        let mut state = self.state.lock().unwrap();
66        let map = get_or_init::<K, V>(&mut state);
67
68        match map.get(k) {
69            Some(val) => SharedState { val: val.clone() },
70            None => {
71                let val = init(k);
72                let val = Arc::new(RwLock::new(val));
73                map.insert(k.clone(), val.clone());
74                SharedState { val: val.clone() }
75            }
76        }
77    }
78
79    /// Inserts a key-value pair into the map.
80    pub fn insert(&self, k: K, v: V) {
81        let mut state = self.state.lock().unwrap();
82        let map = get_or_init::<K, V>(&mut state);
83
84        map.insert(k, Arc::new(RwLock::new(v)));
85    }
86
87    /// Clears the map, removing all key-value pairs.
88    pub fn clear(&self) {
89        let mut state = self.state.lock().unwrap();
90        let map = get_or_init::<K, V>(&mut state);
91        map.clear();
92    }
93}
94
95fn get_or_init<K, V>(state: &mut Option<State<K, V>>) -> &mut State<K, V> {
96    match state {
97        Some(state) => state,
98        None => {
99            *state = Some(State::<K, V>::default());
100            state.as_mut().unwrap()
101        }
102    }
103}