fn_map/
lib.rs

1#![no_std]
2#![doc = include_str!("../README.md")]
3
4pub mod raw;
5
6use core::{cell::UnsafeCell, ptr::NonNull};
7use parking_lot::RwLock;
8use type_key::TypeKey;
9
10use crate::raw::RawFnMap;
11
12#[derive(Debug, Default)]
13/// Single thread only FnMap implementation.
14///
15/// This implementation is zero cost.
16pub struct FnMap(UnsafeCell<RawFnMap>);
17
18impl FnMap {
19    #[inline]
20    pub fn new() -> Self {
21        Self::default()
22    }
23
24    #[inline]
25    pub fn get_ptr<T: 'static + Send>(&self, key_fn: impl FnOnce() -> T) -> NonNull<T> {
26        let key = TypeKey::of_val(&key_fn);
27
28        // SAFETY: safe to borrow shared because self is borrowed shared
29        if let Some(ptr) = unsafe { &*self.0.get().cast_const() }.get(&key) {
30            return ptr;
31        }
32
33        // accuire value first before borrowing exclusively
34        let value = key_fn();
35
36        // SAFETY: safe to borrow exclusively since no one can borrow more
37        unsafe { &mut *self.0.get() }.insert(key, value)
38    }
39
40    /// Get or compute value using key
41    #[inline]
42    pub fn get<T: 'static + Send>(&self, key: impl FnOnce() -> T) -> &T {
43        // SAFETY: pointer is valid and reference cannot outlive more than Self
44        unsafe { self.get_ptr(key).as_ref() }
45    }
46
47    /// Get or compute value using key
48    #[inline]
49    pub fn get_mut<T: 'static + Send>(&mut self, key: impl FnOnce() -> T) -> &mut T {
50        // SAFETY: pointer is valid and reference cannot outlive more than Self
51        unsafe { self.get_ptr(key).as_mut() }
52    }
53
54    /// Reset stored values
55    #[inline]
56    pub fn reset(&mut self) {
57        self.0.get_mut().reset();
58    }
59}
60
61unsafe impl Send for FnMap {}
62
63#[derive(Debug, Default)]
64/// Single thread only and non-Send FnMap implementation
65///
66/// This implementation is zero cost.
67pub struct LocalOnlyFnMap(UnsafeCell<RawFnMap>);
68
69impl LocalOnlyFnMap {
70    #[inline]
71    pub fn new() -> Self {
72        Self::default()
73    }
74
75    #[inline]
76    pub fn get_ptr<T: 'static + Send>(&self, key_fn: impl FnOnce() -> T) -> NonNull<T> {
77        let key = TypeKey::of_val(&key_fn);
78
79        // SAFETY: safe to borrow shared because self is borrowed shared
80        if let Some(ptr) = unsafe { &*self.0.get().cast_const() }.get(&key) {
81            return ptr;
82        }
83
84        // accuire value first before borrowing exclusively
85        let value = key_fn();
86
87        // SAFETY: safe to borrow exclusively since no one can borrow more
88        unsafe { &mut *self.0.get() }.insert(key, value)
89    }
90
91    /// Get or compute value using key
92    #[inline]
93    pub fn get<T: 'static + Send>(&self, key: impl FnOnce() -> T) -> &T {
94        // SAFETY: pointer is valid and reference cannot outlive more than Self
95        unsafe { self.get_ptr(key).as_ref() }
96    }
97
98    /// Get or compute value using key
99    #[inline]
100    pub fn get_mut<T: 'static + Send>(&mut self, key: impl FnOnce() -> T) -> &mut T {
101        // SAFETY: pointer is valid and reference cannot outlive more than Self
102        unsafe { self.get_ptr(key).as_mut() }
103    }
104
105    /// Reset stored values
106    #[inline]
107    pub fn reset(&mut self) {
108        self.0.get_mut().reset();
109    }
110}
111
112#[derive(Debug, Default)]
113/// Thread safe FnMap implementation.
114///
115/// Uses parking_lot's [`RwLock`] to accuire mutable access to Map.
116pub struct ConcurrentFnMap(RwLock<RawFnMap>);
117
118impl ConcurrentFnMap {
119    #[inline]
120    pub fn new() -> Self {
121        Self::default()
122    }
123
124    #[inline]
125    pub fn get_ptr<T: 'static + Send + Sync>(&self, key_fn: impl FnOnce() -> T) -> NonNull<T> {
126        let key = TypeKey::of_val(&key_fn);
127
128        if let Some(ptr) = self.0.read().get(&key) {
129            return ptr;
130        }
131
132        let value = key_fn();
133
134        self.0.write().insert(key, value)
135    }
136
137    /// Get or compute value using key
138    #[inline]
139    pub fn get<T: 'static + Send + Sync>(&self, key_fn: impl FnOnce() -> T) -> &T {
140        // SAFETY: pointer is valid and reference cannot outlive more than Self
141        unsafe { self.get_ptr(key_fn).as_ref() }
142    }
143
144    /// Get or compute value using key
145    #[inline]
146    pub fn get_mut<T: 'static + Send + Sync, F>(&mut self, key_fn: impl FnOnce() -> T) -> &mut T {
147        // SAFETY: pointer is valid and reference cannot outlive more than Self
148        unsafe { self.get_ptr(key_fn).as_mut() }
149    }
150
151    /// Reset stored values
152    #[inline]
153    pub fn reset(&mut self) {
154        self.0.get_mut().reset();
155    }
156}
157
158unsafe impl Send for ConcurrentFnMap {}
159unsafe impl Sync for ConcurrentFnMap {}
160
#[cfg(test)]
mod tests {
    use super::{ConcurrentFnMap, FnMap, LocalOnlyFnMap};

    /// Shared key function used by the tests below.
    fn one() -> i32 {
        1
    }

    #[test]
    fn test_trait() {
        const fn is_send<T: Send>() {}
        const fn is_sync<T: Sync>() {}

        is_send::<FnMap>();

        is_send::<ConcurrentFnMap>();
        is_sync::<ConcurrentFnMap>();
    }

    #[test]
    fn test_local() {
        let map = FnMap::new();

        // The outer key function re-enters the map to fetch another value.
        let b = map.get(|| map.get(one) + 1);
        let a = map.get(one);

        assert_eq!(*b, 2);
        assert_eq!(*a, 1);
    }

    #[test]
    fn test_local_get_mut_and_reset() {
        let mut map = FnMap::new();

        *map.get_mut(one) += 1;
        // The stored value is reused; `one` is not called again.
        assert_eq!(*map.get(one), 2);

        map.reset();
        assert_eq!(*map.get(one), 1);
    }

    #[test]
    fn test_local_only() {
        let map = LocalOnlyFnMap::new();

        // The outer key function re-enters the map to fetch another value.
        let b = map.get(|| map.get(one) + 1);
        let a = map.get(one);

        assert_eq!(*b, 2);
        assert_eq!(*a, 1);
    }

    #[test]
    fn test_local_only_get_mut_and_reset() {
        let mut map = LocalOnlyFnMap::new();

        *map.get_mut(one) += 1;
        // The stored value is reused; `one` is not called again.
        assert_eq!(*map.get(one), 2);

        map.reset();
        assert_eq!(*map.get(one), 1);
    }

    #[test]
    fn test_atomic() {
        let map = ConcurrentFnMap::new();

        // The outer key function re-enters the map to fetch another value.
        let b = map.get(|| map.get(one) + 1);
        let a = map.get(one);

        assert_eq!(*b, 2);
        assert_eq!(*a, 1);
    }
}