// linked/per_thread.rs

use std::{
    collections::{HashMap, hash_map},
    fmt,
    ops::Deref,
    rc::Rc,
    sync::{Arc, RwLock},
    thread::{self, ThreadId},
};

use simple_mermaid::mermaid;

use crate::{BuildThreadIdHasher, ERR_POISONED_LOCK};
12
13/// A wrapper that manages instances of linked objects of type `T`, ensuring that only one
14/// instance of `T` is created per thread.
15///
16/// This is a conceptual equivalent of the [`linked::instance_per_thread!` macro][1], with the main
17/// difference being that this type operates entirely at runtime using dynamic storage and does
18/// not require a static variable to be defined.
19///
20/// # Usage
21///
22/// Create an instance of `PerThread` and provide it an initial instance of a linked object `T`.
23/// This initial instance will be used to create additional instances on demand. Any instance
24/// of `T` retrieved through the same `PerThread` or a clone will be linked to the same family
25/// of `T` instances.
26///
27#[ doc=mermaid!( "../doc/per_thread.mermaid") ]
28///
29/// To access the current thread's instance of `T`, you must first obtain a
30/// [`ThreadLocal<T>`][ThreadLocal] which works in a manner similar to `Rc<T>`, allowing you to
31/// reference the value within. You can obtain a [`ThreadLocal<T>`][ThreadLocal] by calling
32/// [`PerThread::local()`][Self::local].
33///
34/// Once you have a [`ThreadLocal<T>`][ThreadLocal], you can access the `T` within by simply
35/// dereferencing via the `Deref<Target = T>` trait.
36///
37/// # Long-lived thread-specific instances
38///
39/// Note that the `ThreadLocal` type is `!Send`, which means you cannot store it in places that
40/// need to be thread-mobile. For example, in web framework request handlers the compiler might
41/// not permit you to let a `ThreadLocal` live across an `await`, depending on the web framework,
42/// the async task runtime used and its specific configuration.
43///
44/// # Resource management
45///
46/// A thread-specific instance of `T` is dropped when the last `ThreadLocal` on that thread is
47/// dropped. If a new `ThreadLocal` is later obtained, it is initialized with a new instance
48/// of the linked object.
49///
50/// It is important to emphasize that this means if you only create temporary `ThreadLocal`
51/// instances then you will get a new instance of `T` every time. The performance impact of
52/// this depends on how `T` works internally but you are recommended to keep `ThreadLocal`
53/// instances around for reuse when possible.
54///
55/// # Advanced scenarios
56///
57/// Use of `PerThread` does not close the door on other ways to use linked objects.
58/// For example, you always have the possibility of manually taking the `T` and creating
59/// additional clones of it to break out the one-per-thread limitation. The `PerThread` type
60/// only controls what happens through the `PerThread` type.
61///
62/// [1]: crate::instance_per_thread
63#[derive(Debug)]
64pub struct PerThread<T>
65where
66    T: linked::Object,
67{
68    family: FamilyStateReference<T>,
69}
70
71impl<T> PerThread<T>
72where
73    T: linked::Object,
74{
75    /// Creates a new `PerThread` with an existing instance of `T`. Any further access to the `T`
76    /// via the `PerThread` (or its clones) will return instances of `T` from the same family.
77    #[expect(
78        clippy::needless_pass_by_value,
79        reason = "intentional needless consume to encourage all access to go via ThreadLocal<T>"
80    )]
81    #[must_use]
82    pub fn new(inner: T) -> Self {
83        let family = FamilyStateReference::new(inner.handle());
84
85        Self { family }
86    }
87
88    /// Returns a `ThreadLocal<T>` that can be used to efficiently access the current
89    /// thread's `T` instance.
90    ///
91    /// # Example
92    ///
93    /// ```
94    /// # use std::cell::Cell;
95    /// #
96    /// # #[linked::object]
97    /// # struct Thing {
98    /// #     local_value: Cell<usize>,
99    /// # }
100    /// #
101    /// # impl Thing {
102    /// #     pub fn new() -> Self {
103    /// #         linked::new!(Self { local_value: Cell::new(0) })
104    /// #     }
105    /// #
106    /// #     pub fn increment(&self) {
107    /// #        self.local_value.set(self.local_value.get() + 1);
108    /// #     }
109    /// #
110    /// #     pub fn local_value(&self) -> usize {
111    /// #         self.local_value.get()
112    /// #     }
113    /// # }
114    /// #
115    /// let per_thread_thing = linked::PerThread::new(Thing::new());
116    ///
117    /// let local_thing = per_thread_thing.local();
118    /// local_thing.increment();
119    /// assert_eq!(local_thing.local_value(), 1);
120    /// ```
121    ///
122    /// # Efficiency
123    ///
124    /// Reuse the returned instance as much as possible. Every call to this function has some
125    /// overhead, especially if there are no other `ThreadLocal<T>` instances from the same family
126    /// active on the current thread.
127    ///
128    /// # Instance lifecycle
129    ///
130    /// A thread-specific instance of `T` is dropped when the last `ThreadLocal` on that thread is
131    /// dropped. If a new `ThreadLocal` is later obtained, it is initialized with a new instance
132    /// of the linked object.
133    ///
134    /// ```
135    /// # use std::cell::Cell;
136    /// #
137    /// # #[linked::object]
138    /// # struct Thing {
139    /// #     local_value: Cell<usize>,
140    /// # }
141    /// #
142    /// # impl Thing {
143    /// #     pub fn new() -> Self {
144    /// #         linked::new!(Self { local_value: Cell::new(0) })
145    /// #     }
146    /// #
147    /// #     pub fn increment(&self) {
148    /// #        self.local_value.set(self.local_value.get() + 1);
149    /// #     }
150    /// #
151    /// #     pub fn local_value(&self) -> usize {
152    /// #         self.local_value.get()
153    /// #     }
154    /// # }
155    /// #
156    /// let per_thread_thing = linked::PerThread::new(Thing::new());
157    ///
158    /// let local_thing = per_thread_thing.local();
159    /// local_thing.increment();
160    /// assert_eq!(local_thing.local_value(), 1);
161    ///
162    /// drop(local_thing);
163    ///
164    /// // Dropping the only thread-local instance above will have reset the thread-local state.
165    /// let local_thing = per_thread_thing.local();
166    /// assert_eq!(local_thing.local_value(), 0);
167    /// ```
168    ///
169    /// To minimize the effort spent on re-creating the thread-local state, ensure that you reuse
170    /// the `ThreadLocal<T>` instances as much as possible.
171    ///
172    /// # Thread safety
173    ///
174    /// The returned value is single-threaded and cannot be moved or used across threads. For
175    /// transfer across threads, you need to preserve and share/send a `PerThread<T>` instance.
176    #[must_use]
177    pub fn local(&self) -> ThreadLocal<T> {
178        let inner = self.family.current_thread_instance();
179
180        ThreadLocal {
181            inner,
182            family: self.family.clone(),
183        }
184    }
185}
186
187impl<T> Clone for PerThread<T>
188where
189    T: linked::Object,
190{
191    #[inline]
192    fn clone(&self) -> Self {
193        Self {
194            family: self.family.clone(),
195        }
196    }
197}
198
199/// A thread-local instance of a linked object of type `T`. This acts in a manner similar to
200/// `Rc<T>` for a type `T` that implements the [linked object pattern][crate].
201///
202/// For details, see [`PerThread<T>`][PerThread] which is the type used to create instances
203/// of `ThreadLocal<T>`.
204#[derive(Debug)]
205pub struct ThreadLocal<T>
206where
207    T: linked::Object,
208{
209    // We really are just a wrapper around an Rc<T>. The only other duty we have
210    // is to clean up the thread-local instance when the last ThreadLocal is dropped.
211    inner: Rc<T>,
212    family: FamilyStateReference<T>,
213}
214
215impl<T> Deref for ThreadLocal<T>
216where
217    T: linked::Object,
218{
219    type Target = T;
220
221    #[inline]
222    fn deref(&self) -> &Self::Target {
223        &self.inner
224    }
225}
226
227impl<T> Clone for ThreadLocal<T>
228where
229    T: linked::Object,
230{
231    #[inline]
232    fn clone(&self) -> Self {
233        Self {
234            inner: Rc::clone(&self.inner),
235            family: self.family.clone(),
236        }
237    }
238}
239
240impl<T> Drop for ThreadLocal<T>
241where
242    T: linked::Object,
243{
244    fn drop(&mut self) {
245        // If we were the last ThreadLocal on this thread then we need to drop the thread-local
246        // state for this thread. Note that there are 2 references - ourselves and the family state.
247        if Rc::strong_count(&self.inner) != 2 {
248            // No - there is another ThreadLocal, so we do not need to clean up.
249            return;
250        }
251
252        self.family.clear_current_thread_instance();
253
254        // `self.inner` is now the last reference to the current thread's instance of T
255        // and this instance will be dropped once this function returns and drops the last `Rc<T>`.
256    }
257}
258
259/// One reference to the state of a specific family of per-thread linked objects.
260/// This can be used to retrieve and/or initialize the current thread's instance.
261#[derive(Debug)]
262struct FamilyStateReference<T>
263where
264    T: linked::Object,
265{
266    // If a thread needs a new instance, we create it via this handle.
267    handle: linked::Handle<T>,
268
269    // We store the state of each thread here. See safety comments on ThreadSpecificState!
270    // NB! While it is legal to manipulate the HashMap from any thread, including to move
271    // the values, calling actual functions on a value is only valid from the thread in the key.
272    //
273    // To ensure safety, we must also ensure that all values are removed from here before the map
274    // is dropped, because each value must be dropped on the thread that created it and dropping is
275    // logic executed on that thread-specific value!
276    //
277    // This is done in the `ThreadLocal` destructor. By the time this map is dropped, it must be
278    // empty, which we assert in our own drop().
279    //
280    // The write lock here is only held when initializing the thread-specific state for a thread
281    // for the first time, which should generally be rare, especially as user code will also be
282    // motivated to reduce those instances because it also means initializing the actual `T` inside.
283    // Most access will therefore only need to take a read lock.
284    thread_specific: Arc<RwLock<HashMap<ThreadId, ThreadSpecificState<T>, BuildThreadIdHasher>>>,
285}
286
287impl<T> FamilyStateReference<T>
288where
289    T: linked::Object,
290{
291    #[must_use]
292    fn new(handle: linked::Handle<T>) -> Self {
293        Self {
294            handle,
295            thread_specific: Arc::new(RwLock::new(HashMap::with_hasher(BuildThreadIdHasher))),
296        }
297    }
298
299    /// Returns the `Rc<T>` for the current thread, creating it if necessary.
300    #[must_use]
301    fn current_thread_instance(&self) -> Rc<T> {
302        let thread_id = thread::current().id();
303
304        // First, an optimistic pass - let's assume it is already initialized for our thread.
305        {
306            let map = self.thread_specific.read().expect(ERR_POISONED_LOCK);
307
308            if let Some(state) = map.get(&thread_id) {
309                // SAFETY: We must guarantee that we are on the thread that owns
310                // the thread-specific state. We are - thread ID lookup led us here.
311                return unsafe { state.clone_instance() };
312            }
313        }
314
315        // The state for the current thread is not yet initialized. Let's initialize!
316        // Note that we create this instance outside any locks, both to reduce the
317        // lock durations but also because cloning a linked object may execute arbitrary code,
318        // including potentially code that tries to grab the same lock.
319        let instance: Rc<T> = Rc::new(self.handle.clone().into());
320
321        // Let's add the new instance to the map.
322        let mut map = self.thread_specific.write().expect(ERR_POISONED_LOCK);
323
324        // In some wild corner cases, it is perhaps possible that the arbitrary code in the
325        // linked object clone logic may already have filled the map with our value? It is
326        // a bit of a stretch of imagination but let's accept the possibility to be thorough.
327        match map.entry(thread_id) {
328            hash_map::Entry::Occupied(occupied_entry) => {
329                // There already is something in the entry. That's fine, we just ignore the
330                // new instance we created and pretend we are on the optimistic path.
331                let state = occupied_entry.get();
332
333                // SAFETY: We must guarantee that we are on the thread that owns
334                // the thread-specific state. We are - thread ID lookup led us here.
335                unsafe { state.clone_instance() }
336            }
337            hash_map::Entry::Vacant(vacant_entry) => {
338                // We are the first thread to create an instance. Let's insert it.
339                // SAFETY: We must guarantee that any further access (taking the Rc or dropping)
340                // takes place on the same thread as was used to call this function. We ensure this
341                // by the thread ID lookup in the map key - we can only ever directly access map
342                // entries owned by the current thread (though we may resize the map from any
343                // thread, as it simply moves data in memory).
344                let state = unsafe { ThreadSpecificState::new(Rc::clone(&instance)) };
345                vacant_entry.insert(state);
346
347                instance
348            }
349        }
350    }
351
352    fn clear_current_thread_instance(&self) {
353        // We need to clear the thread-specific state for this thread.
354        let thread_id = thread::current().id();
355
356        let mut map = self.thread_specific.write().expect(ERR_POISONED_LOCK);
357        map.remove(&thread_id);
358    }
359}
360
361impl<T> Clone for FamilyStateReference<T>
362where
363    T: linked::Object,
364{
365    fn clone(&self) -> Self {
366        Self {
367            handle: self.handle.clone(),
368            thread_specific: Arc::clone(&self.thread_specific),
369        }
370    }
371}
372
373impl<T> Drop for FamilyStateReference<T>
374where
375    T: linked::Object,
376{
377    #[cfg_attr(test, mutants::skip)] // This is just a sanity check, no functional behavior.
378    fn drop(&mut self) {
379        // If we are the last reference to the family state, this will drop the thread-specific map.
380        // We need to ensure that the thread-specific state is empty before we drop the map.
381        // This is a sanity check - if this fails, we have a defect somewhere in our code.
382
383        if Arc::strong_count(&self.thread_specific) > 1 {
384            // We are not the last reference to the family state,
385            // so no state dropping will occur - having state in the map is fine.
386            return;
387        }
388
389        let map = self.thread_specific.read().expect(ERR_POISONED_LOCK);
390        assert!(
391            map.is_empty(),
392            "thread-specific state map was not empty on drop - internal logic error"
393        );
394    }
395}
396
397/// Holds the thread-specific state for a specific family of per-thread linked objects.
398///
399/// # Safety
400///
401/// This contains an `Rc`, which is `!Send` and only meant to be accessed from the thread it was
402/// created on. Yet the instance of this type itself is visible from multiple threads and
403/// potentially even touched (moved) from another thread when resizing the `HashMap` of all
404/// instances! How can this be?!
405///
406/// We take advantage of the fact that an `Rc` is merely a reference to a control block.
407/// As long as we never touch the control block from the wrong thread, nobody will ever
408/// know we touched the `Rc` from another thread. This allows us to move the Rc around
409/// in memory as long as the move itself is synchronized.
410///
411/// Obviously, this relies on `Rc` implementation details, so we are somewhat at risk of
412/// breakage if a future Rust std implementation changes the way `Rc` works but this seems
413/// unlikely as this is fairly fundamental to the nature of how smart pointers are created.
414///
415/// NB! We must not drop the Rc (and by extension this type) from a foreign thread!
416#[derive(Debug)]
417struct ThreadSpecificState<T>
418where
419    T: linked::Object,
420{
421    instance: Rc<T>,
422}
423
424impl<T> ThreadSpecificState<T>
425where
426    T: linked::Object,
427{
428    /// Creates a new `ThreadSpecificState` with the given `Rc<T>`.
429    ///
430    /// # Safety
431    ///
432    /// The caller must guarantee that any further access (including dropping) takes place on the
433    /// same thread as was used to call this function.
434    ///
435    /// See type-level safety comments for details.
436    #[must_use]
437    unsafe fn new(instance: Rc<T>) -> Self {
438        Self { instance }
439    }
440
441    /// Returns the `Rc<T>` for this thread.
442    ///
443    /// # Safety
444    ///
445    /// The caller must guarantee that the current thread is the thread for which this
446    /// `ThreadSpecificState` was created. This is not enforced by the type system.
447    ///
448    /// See type-level safety comments for details.
449    #[must_use]
450    unsafe fn clone_instance(&self) -> Rc<T> {
451        Rc::clone(&self.instance)
452    }
453}
454
455// SAFETY: See comments on type.
456unsafe impl<T> Sync for ThreadSpecificState<T> where T: linked::Object {}
457// SAFETY: See comments on type.
458unsafe impl<T> Send for ThreadSpecificState<T> where T: linked::Object {}
459
#[cfg(test)]
mod tests {
    use std::{
        cell::Cell,
        sync::{Arc, Mutex},
        thread,
    };

    use super::*;

    // Test fixture: each instance has its own `local_value`, while `shared_value` is shared by
    // the whole family.
    #[linked::object]
    struct TokenCache {
        shared_value: Arc<Mutex<usize>>,
        local_value: Cell<usize>,
    }

    impl TokenCache {
        fn new() -> Self {
            #[expect(
                clippy::mutex_atomic,
                reason = "inner type is placeholder, for realistic usage"
            )]
            let shared_value = Arc::new(Mutex::new(0));

            linked::new!(Self {
                shared_value: Arc::clone(&shared_value),
                local_value: Cell::new(0),
            })
        }

        fn increment(&self) {
            let bumped_local = self.local_value.get().wrapping_add(1);
            self.local_value.set(bumped_local);

            let mut shared = self.shared_value.lock().unwrap();
            *shared = shared.wrapping_add(1);
        }

        fn local_value(&self) -> usize {
            self.local_value.get()
        }

        fn shared_value(&self) -> usize {
            let shared = self.shared_value.lock().unwrap();
            *shared
        }
    }

    #[test]
    fn per_thread_smoke_test() {
        let per_thread = PerThread::new(TokenCache::new());

        let first = per_thread.local();
        first.increment();
        assert_eq!(first.local_value(), 1);
        assert_eq!(first.shared_value(), 1);

        // A second `ThreadLocal` on the same thread must see the very same instance.
        let second = per_thread.local();
        assert_eq!(second.local_value(), 1);
        assert_eq!(second.shared_value(), 1);

        second.increment();
        assert_eq!(first.local_value(), 2);
        assert_eq!(first.shared_value(), 2);

        thread::spawn(move || {
            // `PerThread` itself may be moved to another thread.
            let on_worker = per_thread.local();

            // The worker thread gets its own instance, so its local value starts from scratch.
            assert_eq!(on_worker.local_value(), 0);
            assert_eq!(on_worker.shared_value(), 2);

            on_worker.increment();
            assert_eq!(on_worker.local_value(), 1);
            assert_eq!(on_worker.shared_value(), 3);

            // Clones of `PerThread` behave exactly like the original.
            let per_thread_clone = per_thread.clone();

            let via_clone = per_thread_clone.local();
            assert_eq!(via_clone.local_value(), 1);
            assert_eq!(via_clone.shared_value(), 3);

            // All `PerThread` values of one family are interchangeable.
            let via_original = per_thread.local();
            assert_eq!(via_original.local_value(), 1);
            assert_eq!(via_original.shared_value(), 3);

            thread::spawn(move || {
                let on_nested = per_thread_clone.local();

                // Yet another thread, so yet another fresh instance.
                assert_eq!(on_nested.local_value(), 0);
                assert_eq!(on_nested.shared_value(), 3);

                on_nested.increment();
                assert_eq!(on_nested.local_value(), 1);
                assert_eq!(on_nested.shared_value(), 4);
            })
            .join()
            .unwrap();
        })
        .join()
        .unwrap();

        assert_eq!(first.local_value(), 2);
        assert_eq!(first.shared_value(), 4);
    }

    #[test]
    fn thread_state_dropped_on_last_thread_local_drop() {
        let per_thread = PerThread::new(TokenCache::new());

        let local = per_thread.local();
        local.increment();
        assert_eq!(local.local_value(), 1);

        // Releasing the only `ThreadLocal` discards this thread's instance...
        drop(local);

        // ...so the next one is built from scratch.
        let local = per_thread.local();
        assert_eq!(local.local_value(), 0);
    }

    #[test]
    fn thread_state_dropped_on_thread_exit() {
        // No thread-specific state exists yet. The link embedded into the `PerThread` holds one
        // reference to the inner shared value of the `TokenCache`.
        let per_thread = PerThread::new(TokenCache::new());

        let local = per_thread.local();

        // The link plus this thread's instance.
        assert_eq!(Arc::strong_count(&local.shared_value), 2);

        thread::spawn(move || {
            let remote = per_thread.local();

            // The worker thread's instance adds one more.
            assert_eq!(Arc::strong_count(&remote.shared_value), 3);
        })
        .join()
        .unwrap();

        // The worker's thread-local state went away when the worker thread finished.
        assert_eq!(Arc::strong_count(&local.shared_value), 2);
    }
}