Skip to main content

linked/
instance_per_thread.rs

1use std::collections::{HashMap, hash_map};
2use std::ops::Deref;
3use std::rc::Rc;
4use std::sync::{Arc, RwLock};
5use std::thread::{self, ThreadId};
6
7use simple_mermaid::mermaid;
8
9use crate::{BuildThreadIdHasher, ERR_POISONED_LOCK};
10
/// A wrapper that manages linked instances of `T`, ensuring that only one
/// instance of `T` is created per thread.
///
/// This is similar to the [`linked::thread_local_rc!` macro][1], with the main difference
/// being that this type operates entirely at runtime using dynamic storage and does
/// not require a static variable to be defined.
///
/// # Usage
///
/// Create an instance of `InstancePerThread` and provide it the initial instance of a linked
/// object `T`. Any instance of `T` accessed through the same `InstancePerThread` or a clone of it
/// will be linked to the same family.
#[ doc=mermaid!( "../doc/instance_per_thread.mermaid") ]
///
/// To access the current thread's instance of `T`, you must first obtain a
/// [`Ref<T>`][Ref] by calling [`.acquire()`][Self::acquire]. Then you can
/// access the `T` within by simply dereferencing via the `Deref<Target = T>` trait.
///
/// `Ref<T>` is a thread-isolated type, meaning you cannot move it to a different
/// thread nor access it from a different thread. To access linked instances on other threads,
/// you must transfer the `InstancePerThread<T>` instance across threads and obtain a new `Ref<T>`
/// on the destination thread.
///
/// # Resource management
///
/// A thread-specific instance of `T` is dropped when the last `Ref` on that thread
/// is dropped, similar to how `Rc<T>` would behave. If a new `Ref` is later obtained,
/// it is initialized with a new instance of the linked object.
///
/// It is important to emphasize that this means if you only acquire temporary `Ref`
/// objects then you will get a new instance of `T` every time. The performance impact of
/// this depends on how `T` works internally but you are recommended to keep `Ref`
/// instances around for reuse when possible.
///
/// # `Ref` storage
///
/// `Ref` is a thread-isolated type, which means you cannot store it in places that require types
/// to be thread-mobile. For example, in web framework request handlers the compiler might not
/// permit you to let a `Ref` live across an `await`, depending on the web framework, the async
/// task runtime used and its specific configuration.
///
/// Consider using `InstancePerThreadSync<T>` if you need a thread-mobile variant of `Ref`.
///
/// [1]: crate::thread_local_rc
#[derive(Debug)]
pub struct InstancePerThread<T>
where
    T: linked::Object,
{
    // Shared state of the family of linked objects. Every clone of this
    // `InstancePerThread` references the same family state, which is how all
    // clones hand out instances belonging to the same family.
    family: FamilyStateReference<T>,
}
62
63impl<T> InstancePerThread<T>
64where
65    T: linked::Object,
66{
67    /// Creates a new `InstancePerThread` with an existing instance of `T`.
68    ///
69    /// Any further access of `T` instances via the `InstancePerThread` (or its clones) will return
70    /// instances of `T` from the same family.
71    #[expect(
72        clippy::needless_pass_by_value,
73        reason = "intentional needless consume to encourage all access to go via InstancePerThread<T>"
74    )]
75    #[must_use]
76    pub fn new(inner: T) -> Self {
77        let family = FamilyStateReference::new(inner.family());
78
79        Self { family }
80    }
81
82    /// Returns a `Ref<T>` that can be used to access the current thread's instance of `T`.
83    ///
84    /// Creating multiple concurrent `Ref<T>` instances from the same `InstancePerThread<T>`
85    /// on the same thread is allowed. Every `Ref<T>` instance will reference the same
86    /// instance of `T` per thread.
87    ///
88    /// There are no constraints on the lifetime of the returned `Ref<T>` but it is a
89    /// thread-isolated type and cannot be moved across threads or accessed from a different thread,
90    /// which may impose some limits.
91    ///
92    /// # Example
93    ///
94    /// ```
95    /// # use std::cell::Cell;
96    /// #
97    /// # #[linked::object]
98    /// # struct Thing {
99    /// #     local_value: Cell<usize>,
100    /// # }
101    /// #
102    /// # impl Thing {
103    /// #     pub fn new() -> Self {
104    /// #         linked::new!(Self { local_value: Cell::new(0) })
105    /// #     }
106    /// #
107    /// #     pub fn increment(&self) {
108    /// #        self.local_value.set(self.local_value.get() + 1);
109    /// #     }
110    /// #
111    /// #     pub fn local_value(&self) -> usize {
112    /// #         self.local_value.get()
113    /// #     }
114    /// # }
115    /// #
116    /// use linked::InstancePerThread;
117    ///
118    /// let linked_thing = InstancePerThread::new(Thing::new());
119    ///
120    /// let thing = linked_thing.acquire();
121    /// thing.increment();
122    /// assert_eq!(thing.local_value(), 1);
123    /// ```
124    ///
125    /// # Efficiency
126    ///
127    /// Reuse the returned `Ref<T>` when possible. Every call to this function has
128    /// some overhead, especially if there are no other `Ref<T>` instances from the
129    /// same family active on the current thread.
130    ///
131    /// # Instance lifecycle
132    ///
133    /// A thread-specific instance of `T` is dropped when the last `Ref` on that
134    /// thread is dropped. If a new `Ref` is later obtained, it is initialized
135    /// with a new linked instance of `T` linked to the same family as the
136    /// originating `InstancePerThread<T>`.
137    ///
138    /// ```
139    /// # use std::cell::Cell;
140    /// #
141    /// # #[linked::object]
142    /// # struct Thing {
143    /// #     local_value: Cell<usize>,
144    /// # }
145    /// #
146    /// # impl Thing {
147    /// #     pub fn new() -> Self {
148    /// #         linked::new!(Self { local_value: Cell::new(0) })
149    /// #     }
150    /// #
151    /// #     pub fn increment(&self) {
152    /// #        self.local_value.set(self.local_value.get() + 1);
153    /// #     }
154    /// #
155    /// #     pub fn local_value(&self) -> usize {
156    /// #         self.local_value.get()
157    /// #     }
158    /// # }
159    /// #
160    /// use linked::InstancePerThread;
161    ///
162    /// let linked_thing = InstancePerThread::new(Thing::new());
163    ///
164    /// let thing = linked_thing.acquire();
165    /// thing.increment();
166    /// assert_eq!(thing.local_value(), 1);
167    ///
168    /// drop(thing);
169    ///
170    /// // Dropping the only acquired instance above will have reset the thread-local state.
171    /// let thing = linked_thing.acquire();
172    /// assert_eq!(thing.local_value(), 0);
173    /// ```
174    ///
175    /// To minimize the effort spent on re-creating the thread-local state, ensure that you reuse
176    /// the `Ref<T>` instances as much as possible.
177    ///
178    /// # Thread safety
179    ///
180    /// The returned value is single-threaded and cannot be moved or used across threads. To extend
181    /// the linked object family across threads, transfer `InstancePerThread<T>` instances across threads.
182    /// You can obtain additional `InstancePerThread<T>` instances by cloning the original. Every clone
183    /// is equivalent.
184    #[must_use]
185    pub fn acquire(&self) -> Ref<T> {
186        let inner = self.family.current_thread_instance();
187
188        Ref {
189            inner,
190            family: self.family.clone(),
191        }
192    }
193}
194
195impl<T> Clone for InstancePerThread<T>
196where
197    T: linked::Object,
198{
199    #[inline]
200    fn clone(&self) -> Self {
201        Self {
202            family: self.family.clone(),
203        }
204    }
205}
206
/// An acquired thread-local instance of a linked object of type `T`,
/// implementing `Deref<Target = T>`.
///
/// For details, see [`InstancePerThread<T>`][InstancePerThread] which is the type used
/// to create instances of `Ref<T>`.
#[derive(Debug)]
pub struct Ref<T>
where
    T: linked::Object,
{
    // We really are just a wrapper around an Rc<T>. The only other duty we have
    // is to clean up the thread-local instance when the last `Ref` is dropped.
    inner: Rc<T>,
    // Used by the Drop impl to remove the current thread's entry from the family's
    // per-thread map once `inner` is about to become the last reference.
    family: FamilyStateReference<T>,
}
222
223impl<T> Deref for Ref<T>
224where
225    T: linked::Object,
226{
227    type Target = T;
228
229    #[inline]
230    fn deref(&self) -> &Self::Target {
231        &self.inner
232    }
233}
234
235impl<T> Clone for Ref<T>
236where
237    T: linked::Object,
238{
239    #[inline]
240    fn clone(&self) -> Self {
241        Self {
242            inner: Rc::clone(&self.inner),
243            family: self.family.clone(),
244        }
245    }
246}
247
impl<T> Drop for Ref<T>
where
    T: linked::Object,
{
    fn drop(&mut self) {
        // If we were the last Ref on this thread then we need to drop the thread-local
        // state for this thread. Note that there are 2 references - ourselves and the family state.
        //
        // Reference accounting: the family's per-thread map holds exactly one `Rc<T>`
        // for this thread (inserted by `current_thread_instance`), and every live `Ref`
        // on this thread holds one more. A strong count of exactly 2 therefore means
        // this `Ref` is the last one on the thread.
        if Rc::strong_count(&self.inner) != 2 {
            // No - there is another Ref, so we do not need to clean up.
            return;
        }

        // Removes this thread's entry from the map, dropping the map's `Rc<T>`.
        // This runs on the owning thread, as required by ThreadSpecificState.
        self.family.clear_current_thread_instance();

        // `self.inner` is now the last reference to the current thread's instance of T
        // and this instance will be dropped once this function returns and drops the last `Rc<T>`.
    }
}
266
/// One reference to the state of a specific family of per-thread linked objects.
/// This can be used to retrieve and/or initialize the current thread's instance.
#[derive(Debug)]
struct FamilyStateReference<T>
where
    T: linked::Object,
{
    // If a thread needs a new instance, we create it via the family.
    family: linked::Family<T>,

    // We store the state of each thread here. See safety comments on ThreadSpecificState!
    // NB! While it is legal to manipulate the HashMap from any thread, including to move
    // the values, calling actual functions on a value is only valid from the thread in the key.
    //
    // To ensure safety, we must also ensure that all values are removed from here before the map
    // is dropped, because each value must be dropped on the thread that created it and dropping is
    // logic executed on that thread-specific value!
    //
    // This is done in the `Ref` destructor. By the time this map is dropped,
    // it must be empty, which we assert in our own drop().
    //
    // The write lock here is only held when initializing the thread-specific state for a thread
    // for the first time, which should generally be rare, especially as user code will also be
    // motivated to reduce those instances because it also means initializing the actual `T` inside.
    // Most access will therefore only need to take a read lock.
    thread_specific: Arc<RwLock<HashMap<ThreadId, ThreadSpecificState<T>, BuildThreadIdHasher>>>,
}
294
impl<T> FamilyStateReference<T>
where
    T: linked::Object,
{
    /// Creates the state for a family, starting with no thread-specific instances.
    #[must_use]
    fn new(family: linked::Family<T>) -> Self {
        Self {
            family,
            thread_specific: Arc::new(RwLock::new(HashMap::with_hasher(BuildThreadIdHasher))),
        }
    }

    /// Returns the `Rc<T>` for the current thread, creating it if necessary.
    #[must_use]
    fn current_thread_instance(&self) -> Rc<T> {
        let thread_id = thread::current().id();

        // First, an optimistic pass - let us assume it is already initialized for our thread.
        {
            let map = self.thread_specific.read().expect(ERR_POISONED_LOCK);

            if let Some(state) = map.get(&thread_id) {
                // SAFETY: We must guarantee that we are on the thread that owns
                // the thread-specific state. We are - thread ID lookup led us here.
                return unsafe { state.clone_instance() };
            }
        }
        // NB: the read lock is released at the end of the scope above, before any
        // arbitrary user code can run during instance creation below.

        // The state for the current thread is not yet initialized. Let us initialize!
        // Note that we create this instance outside any locks, both to reduce the
        // lock durations but also because cloning a linked object may execute arbitrary code,
        // including potentially code that tries to grab the same lock.
        let instance: Rc<T> = Rc::new(self.family.clone().into());

        // Let us add the new instance to the map.
        let mut map = self.thread_specific.write().expect(ERR_POISONED_LOCK);

        // In some wild corner cases, it is perhaps possible that the arbitrary code in the
        // linked object clone logic may already have filled the map with our value? It is
        // a bit of a stretch of imagination but let us accept the possibility to be thorough.
        match map.entry(thread_id) {
            hash_map::Entry::Occupied(occupied_entry) => {
                // There already is something in the entry. That is fine, we just ignore the
                // new instance we created and pretend we are on the optimistic path.
                let state = occupied_entry.get();

                // SAFETY: We must guarantee that we are on the thread that owns
                // the thread-specific state. We are - thread ID lookup led us here.
                unsafe { state.clone_instance() }
            }
            hash_map::Entry::Vacant(vacant_entry) => {
                // We are the first thread to create an instance. Let us insert it.
                // SAFETY: We must guarantee that any further access (taking the Rc or dropping)
                // takes place on the same thread as was used to call this function. We ensure this
                // by the thread ID lookup in the map key - we can only ever directly access map
                // entries owned by the current thread (though we may resize the map from any
                // thread, as it simply moves data in memory).
                let state = unsafe { ThreadSpecificState::new(Rc::clone(&instance)) };
                vacant_entry.insert(state);

                instance
            }
        }
    }

    /// Removes the current thread's instance from the map, if one exists.
    ///
    /// Called from the `Ref` destructor on the owning thread, so the removed
    /// `ThreadSpecificState` (and the `Rc` inside) is dropped on the correct thread.
    fn clear_current_thread_instance(&self) {
        // We need to clear the thread-specific state for this thread.
        let thread_id = thread::current().id();

        let mut map = self.thread_specific.write().expect(ERR_POISONED_LOCK);
        map.remove(&thread_id);
    }
}
368
369impl<T> Clone for FamilyStateReference<T>
370where
371    T: linked::Object,
372{
373    fn clone(&self) -> Self {
374        Self {
375            family: self.family.clone(),
376            thread_specific: Arc::clone(&self.thread_specific),
377        }
378    }
379}
380
381impl<T> Drop for FamilyStateReference<T>
382where
383    T: linked::Object,
384{
385    #[cfg_attr(test, mutants::skip)] // This is just a sanity check, no functional behavior.
386    fn drop(&mut self) {
387        // If we are the last reference to the family state, this will drop the thread-specific map.
388        // We need to ensure that the thread-specific state is empty before we drop the map.
389        // This is a sanity check - if this fails, we have a defect somewhere in our code.
390
391        if Arc::strong_count(&self.thread_specific) > 1 {
392            // We are not the last reference to the family state,
393            // so no state dropping will occur - having state in the map is fine.
394            return;
395        }
396
397        if thread::panicking() {
398            // If we are already panicking, there is no point in asserting anything,
399            // as another panic may disrupt the handling of the original panic.
400            return;
401        }
402
403        let map = self.thread_specific.read().expect(ERR_POISONED_LOCK);
404        assert!(
405            map.is_empty(),
406            "thread-specific state map was not empty on drop - internal logic error"
407        );
408    }
409}
410
/// Holds the thread-specific state for a specific family of per-thread linked objects.
///
/// # Safety
///
/// This contains an `Rc`, which is `!Send` and only meant to be accessed from the thread it was
/// created on. Yet the instance of this type itself is visible from multiple threads and
/// potentially even touched (moved) from another thread when resizing the `HashMap` of all
/// instances! How can this be?!
///
/// We take advantage of the fact that an `Rc` is merely a reference to a control block.
/// As long as we never touch the control block from the wrong thread, nobody will ever
/// know we touched the `Rc` from another thread. This allows us to move the Rc around
/// in memory as long as the move itself is synchronized.
///
/// Obviously, this relies on `Rc` implementation details, so we are somewhat at risk of
/// breakage if a future Rust std implementation changes the way `Rc` works but this seems
/// unlikely as this is fairly fundamental to the nature of how smart pointers are created.
///
/// NB! We must not drop the Rc (and by extension this type) from a foreign thread!
#[derive(Debug)]
struct ThreadSpecificState<T>
where
    T: linked::Object,
{
    // The one instance of `T` shared by all `Ref`s on the owning thread. Only ever
    // cloned or dropped on that thread; see the safety comments above.
    instance: Rc<T>,
}
437
impl<T> ThreadSpecificState<T>
where
    T: linked::Object,
{
    /// Creates a new `ThreadSpecificState` with the given `Rc<T>`.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that any further access (including dropping) takes place on the
    /// same thread as was used to call this function.
    ///
    /// See type-level safety comments for details.
    #[must_use]
    unsafe fn new(instance: Rc<T>) -> Self {
        Self { instance }
    }

    /// Returns the `Rc<T>` for this thread.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that the current thread is the thread for which this
    /// `ThreadSpecificState` was created. This is not enforced by the type system.
    ///
    /// See type-level safety comments for details.
    #[must_use]
    unsafe fn clone_instance(&self) -> Rc<T> {
        // Cloning increments the strong count in the Rc control block, which is
        // only sound because the caller guarantees we are on the owning thread.
        Rc::clone(&self.instance)
    }
}
468
// SAFETY: See comments on type - the Rc control block is never touched from a foreign
// thread; other threads may only move the value in memory (e.g. on HashMap resize).
unsafe impl<T> Sync for ThreadSpecificState<T> where T: linked::Object {}
// SAFETY: See comments on type - same reasoning as for Sync above.
unsafe impl<T> Send for ThreadSpecificState<T> where T: linked::Object {}
473
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use std::cell::Cell;
    use std::sync::{Arc, Mutex};
    use std::thread;

    use super::*;

    // Test fixture: a linked object with a per-instance (thread-local) counter
    // and a family-shared counter, so tests can distinguish which state is shared.
    #[linked::object]
    struct TokenCache {
        shared_value: Arc<Mutex<usize>>,
        local_value: Cell<usize>,
    }

    impl TokenCache {
        fn new() -> Self {
            let shared_value = Arc::new(Mutex::new(0));

            linked::new!(Self {
                shared_value: Arc::clone(&shared_value),
                local_value: Cell::new(0),
            })
        }

        // Increments both the instance-local and the family-shared counters.
        fn increment(&self) {
            self.local_value.set(self.local_value.get().wrapping_add(1));

            let mut shared_value = self.shared_value.lock().unwrap();
            *shared_value = shared_value.wrapping_add(1);
        }

        fn local_value(&self) -> usize {
            self.local_value.get()
        }

        fn shared_value(&self) -> usize {
            *self.shared_value.lock().unwrap()
        }
    }

    #[test]
    fn per_thread_smoke_test() {
        let linked_cache = InstancePerThread::new(TokenCache::new());

        let cache1 = linked_cache.acquire();
        cache1.increment();

        assert_eq!(cache1.local_value(), 1);
        assert_eq!(cache1.shared_value(), 1);

        // This must refer to the same instance.
        let cache2 = linked_cache.acquire();

        assert_eq!(cache2.local_value(), 1);
        assert_eq!(cache2.shared_value(), 1);

        cache2.increment();

        assert_eq!(cache1.local_value(), 2);
        assert_eq!(cache1.shared_value(), 2);

        thread::spawn(move || {
            // You can move InstancePerThread across threads.
            let cache3 = linked_cache.acquire();

            // This is a different thread's instance, so the local value is fresh.
            assert_eq!(cache3.local_value(), 0);
            assert_eq!(cache3.shared_value(), 2);

            cache3.increment();

            assert_eq!(cache3.local_value(), 1);
            assert_eq!(cache3.shared_value(), 3);

            // You can clone this and every clone works the same.
            let thread_local_clone = linked_cache.clone();

            let cache4 = thread_local_clone.acquire();

            assert_eq!(cache4.local_value(), 1);
            assert_eq!(cache4.shared_value(), 3);

            // Every InstancePerThread instance from the same family is equivalent.
            let cache5 = linked_cache.acquire();

            assert_eq!(cache5.local_value(), 1);
            assert_eq!(cache5.shared_value(), 3);

            thread::spawn(move || {
                let cache6 = thread_local_clone.acquire();

                // This is a different thread's instance, so the local value is fresh.
                assert_eq!(cache6.local_value(), 0);
                assert_eq!(cache6.shared_value(), 3);

                cache6.increment();

                assert_eq!(cache6.local_value(), 1);
                assert_eq!(cache6.shared_value(), 4);
            })
            .join()
            .unwrap();
        })
        .join()
        .unwrap();

        // Back on the original thread: local state is untouched by the other
        // threads, while the shared counter reflects all increments.
        assert_eq!(cache1.local_value(), 2);
        assert_eq!(cache1.shared_value(), 4);
    }

    #[test]
    fn thread_state_dropped_on_last_thread_local_drop() {
        let linked_cache = InstancePerThread::new(TokenCache::new());

        let cache = linked_cache.acquire();
        cache.increment();

        assert_eq!(cache.local_value(), 1);

        // This will drop the local state.
        drop(cache);

        // We get a fresh instance now, initialized from scratch for this thread.
        let cache = linked_cache.acquire();
        assert_eq!(cache.local_value(), 0);
    }

    #[test]
    fn thread_state_dropped_on_thread_exit() {
        // At the start, no thread-specific state has been created. The link embedded into the
        // InstancePerThread holds one reference to the inner shared value of the TokenCache.
        let linked_cache = InstancePerThread::new(TokenCache::new());

        let cache = linked_cache.acquire();

        // We now have two references to the inner shared value - the link + this fn.
        assert_eq!(Arc::strong_count(&cache.shared_value), 2);

        thread::spawn(move || {
            let cache = linked_cache.acquire();

            assert_eq!(Arc::strong_count(&cache.shared_value), 3);
        })
        .join()
        .unwrap();

        // Should be back to 2 here - the thread-local state was dropped when the thread exited.
        assert_eq!(Arc::strong_count(&cache.shared_value), 2);
    }
}