linked/instance_per_thread.rs
1use std::collections::{HashMap, hash_map};
2use std::ops::Deref;
3use std::rc::Rc;
4use std::sync::{Arc, RwLock};
5use std::thread::{self, ThreadId};
6
7use simple_mermaid::mermaid;
8
9use crate::{BuildThreadIdHasher, ERR_POISONED_LOCK};
10
/// A wrapper that manages linked instances of `T`, ensuring that only one
/// instance of `T` is created per thread.
///
/// This is similar to the [`linked::thread_local_rc!` macro][1], with the main difference
/// being that this type operates entirely at runtime using dynamic storage and does
/// not require a static variable to be defined.
///
/// # Usage
///
/// Create an instance of `InstancePerThread` and provide it the initial instance of a linked
/// object `T`. Any instance of `T` accessed through the same `InstancePerThread` or a clone of it
/// will be linked to the same family.
#[ doc=mermaid!( "../doc/instance_per_thread.mermaid") ]
///
/// To access the current thread's instance of `T`, you must first obtain a
/// [`Ref<T>`][Ref] by calling [`.acquire()`][Self::acquire]. Then you can
/// access the `T` within by simply dereferencing via the `Deref<Target = T>` trait.
///
/// `Ref<T>` is a thread-isolated type, meaning you cannot move it to a different
/// thread nor access it from a different thread. To access linked instances on other threads,
/// you must transfer the `InstancePerThread<T>` instance across threads and obtain a new `Ref<T>`
/// on the destination thread.
///
/// # Resource management
///
/// A thread-specific instance of `T` is dropped when the last `Ref` on that thread
/// is dropped, similar to how `Rc<T>` would behave. If a new `Ref` is later obtained,
/// it is initialized with a new instance of the linked object.
///
/// It is important to emphasize that this means if you only acquire temporary `Ref`
/// objects then you will get a new instance of `T` every time. The performance impact of
/// this depends on how `T` works internally but you are recommended to keep `Ref`
/// instances around for reuse when possible.
///
/// # `Ref` storage
///
/// `Ref` is a thread-isolated type, which means you cannot store it in places that require types
/// to be thread-mobile. For example, in web framework request handlers the compiler might not
/// permit you to let a `Ref` live across an `await`, depending on the web framework, the async
/// task runtime used and its specific configuration.
///
/// Consider using `InstancePerThreadSync<T>` if you need a thread-mobile variant of `Ref`.
///
/// [1]: crate::thread_local_rc
#[derive(Debug)]
pub struct InstancePerThread<T>
where
    T: linked::Object,
{
    // Shared handle to the family state. Every clone of this wrapper holds a
    // clone of this reference, so all clones observe the same per-thread map.
    family: FamilyStateReference<T>,
}
62
impl<T> InstancePerThread<T>
where
    T: linked::Object,
{
    /// Creates a new `InstancePerThread` with an existing instance of `T`.
    ///
    /// Any further access of `T` instances via the `InstancePerThread` (or its clones) will return
    /// instances of `T` from the same family.
    #[expect(
        clippy::needless_pass_by_value,
        reason = "intentional needless consume to encourage all access to go via InstancePerThread<T>"
    )]
    #[must_use]
    pub fn new(inner: T) -> Self {
        // Only the family handle is retained; the provided instance itself is
        // discarded when this function returns, so each thread later gets a
        // freshly created member of the same family.
        let family = FamilyStateReference::new(inner.family());

        Self { family }
    }

    /// Returns a `Ref<T>` that can be used to access the current thread's instance of `T`.
    ///
    /// Creating multiple concurrent `Ref<T>` instances from the same `InstancePerThread<T>`
    /// on the same thread is allowed. Every `Ref<T>` instance will reference the same
    /// instance of `T` per thread.
    ///
    /// There are no constraints on the lifetime of the returned `Ref<T>` but it is a
    /// thread-isolated type and cannot be moved across threads or accessed from a different thread,
    /// which may impose some limits.
    ///
    /// # Example
    ///
    /// ```
    /// # use std::cell::Cell;
    /// #
    /// # #[linked::object]
    /// # struct Thing {
    /// #     local_value: Cell<usize>,
    /// # }
    /// #
    /// # impl Thing {
    /// #     pub fn new() -> Self {
    /// #         linked::new!(Self { local_value: Cell::new(0) })
    /// #     }
    /// #
    /// #     pub fn increment(&self) {
    /// #         self.local_value.set(self.local_value.get() + 1);
    /// #     }
    /// #
    /// #     pub fn local_value(&self) -> usize {
    /// #         self.local_value.get()
    /// #     }
    /// # }
    /// #
    /// use linked::InstancePerThread;
    ///
    /// let linked_thing = InstancePerThread::new(Thing::new());
    ///
    /// let thing = linked_thing.acquire();
    /// thing.increment();
    /// assert_eq!(thing.local_value(), 1);
    /// ```
    ///
    /// # Efficiency
    ///
    /// Reuse the returned `Ref<T>` when possible. Every call to this function has
    /// some overhead, especially if there are no other `Ref<T>` instances from the
    /// same family active on the current thread.
    ///
    /// # Instance lifecycle
    ///
    /// A thread-specific instance of `T` is dropped when the last `Ref` on that
    /// thread is dropped. If a new `Ref` is later obtained, it is initialized
    /// with a new linked instance of `T` linked to the same family as the
    /// originating `InstancePerThread<T>`.
    ///
    /// ```
    /// # use std::cell::Cell;
    /// #
    /// # #[linked::object]
    /// # struct Thing {
    /// #     local_value: Cell<usize>,
    /// # }
    /// #
    /// # impl Thing {
    /// #     pub fn new() -> Self {
    /// #         linked::new!(Self { local_value: Cell::new(0) })
    /// #     }
    /// #
    /// #     pub fn increment(&self) {
    /// #         self.local_value.set(self.local_value.get() + 1);
    /// #     }
    /// #
    /// #     pub fn local_value(&self) -> usize {
    /// #         self.local_value.get()
    /// #     }
    /// # }
    /// #
    /// use linked::InstancePerThread;
    ///
    /// let linked_thing = InstancePerThread::new(Thing::new());
    ///
    /// let thing = linked_thing.acquire();
    /// thing.increment();
    /// assert_eq!(thing.local_value(), 1);
    ///
    /// drop(thing);
    ///
    /// // Dropping the only acquired instance above will have reset the thread-local state.
    /// let thing = linked_thing.acquire();
    /// assert_eq!(thing.local_value(), 0);
    /// ```
    ///
    /// To minimize the effort spent on re-creating the thread-local state, ensure that you reuse
    /// the `Ref<T>` instances as much as possible.
    ///
    /// # Thread safety
    ///
    /// The returned value is single-threaded and cannot be moved or used across threads. To extend
    /// the linked object family across threads, transfer `InstancePerThread<T>` instances across threads.
    /// You can obtain additional `InstancePerThread<T>` instances by cloning the original. Every clone
    /// is equivalent.
    #[must_use]
    pub fn acquire(&self) -> Ref<T> {
        // Retrieves this thread's `Rc<T>`, lazily creating it on first access.
        let inner = self.family.current_thread_instance();

        Ref {
            inner,
            family: self.family.clone(),
        }
    }
}
194
195impl<T> Clone for InstancePerThread<T>
196where
197 T: linked::Object,
198{
199 #[inline]
200 fn clone(&self) -> Self {
201 Self {
202 family: self.family.clone(),
203 }
204 }
205}
206
/// An acquired thread-local instance of a linked object of type `T`,
/// implementing `Deref<Target = T>`.
///
/// For details, see [`InstancePerThread<T>`][InstancePerThread] which is the type used
/// to create instances of `Ref<T>`.
#[derive(Debug)]
pub struct Ref<T>
where
    T: linked::Object,
{
    // We really are just a wrapper around an Rc<T>. The only other duty we have
    // is to clean up the thread-local instance when the last `Ref` is dropped.
    inner: Rc<T>,

    // Handle back to the family state so our Drop impl can remove the current
    // thread's map entry once `inner` is the last user-visible reference.
    family: FamilyStateReference<T>,
}
222
223impl<T> Deref for Ref<T>
224where
225 T: linked::Object,
226{
227 type Target = T;
228
229 #[inline]
230 fn deref(&self) -> &Self::Target {
231 &self.inner
232 }
233}
234
235impl<T> Clone for Ref<T>
236where
237 T: linked::Object,
238{
239 #[inline]
240 fn clone(&self) -> Self {
241 Self {
242 inner: Rc::clone(&self.inner),
243 family: self.family.clone(),
244 }
245 }
246}
247
impl<T> Drop for Ref<T>
where
    T: linked::Object,
{
    fn drop(&mut self) {
        // If we were the last Ref on this thread then we need to drop the thread-local
        // state for this thread. Note that there are 2 references - ourselves and the family state.
        // (Each additional live `Ref` on this thread adds one more, so a count above 2
        // means other `Ref`s still exist and we must leave the map entry alone.)
        if Rc::strong_count(&self.inner) != 2 {
            // No - there is another Ref, so we do not need to clean up.
            return;
        }

        // Removes this thread's entry from the family map, dropping the map's copy
        // of the `Rc<T>` (count goes from 2 to 1). This runs on the owning thread,
        // which is required by the safety contract of `ThreadSpecificState`.
        self.family.clear_current_thread_instance();

        // `self.inner` is now the last reference to the current thread's instance of T
        // and this instance will be dropped once this function returns and drops the last `Rc<T>`.
    }
}
266
/// One reference to the state of a specific family of per-thread linked objects.
/// This can be used to retrieve and/or initialize the current thread's instance.
#[derive(Debug)]
struct FamilyStateReference<T>
where
    T: linked::Object,
{
    // If a thread needs a new instance, we create it via the family.
    family: linked::Family<T>,

    // We store the state of each thread here. See safety comments on ThreadSpecificState!
    // NB! While it is legal to manipulate the HashMap from any thread, including to move
    // the values, calling actual functions on a value is only valid from the thread in the key.
    //
    // To ensure safety, we must also ensure that all values are removed from here before the map
    // is dropped, because each value must be dropped on the thread that created it and dropping is
    // logic executed on that thread-specific value!
    //
    // This is done in the `Ref` destructor. By the time this map is dropped,
    // it must be empty, which we assert in our own drop().
    //
    // The write lock here is only held when initializing the thread-specific state for a thread
    // for the first time, which should generally be rare, especially as user code will also be
    // motivated to reduce those instances because it also means initializing the actual `T` inside.
    // Most access will therefore only need to take a read lock.
    thread_specific: Arc<RwLock<HashMap<ThreadId, ThreadSpecificState<T>, BuildThreadIdHasher>>>,
}
294
impl<T> FamilyStateReference<T>
where
    T: linked::Object,
{
    /// Creates the family state with an initially empty per-thread instance map.
    #[must_use]
    fn new(family: linked::Family<T>) -> Self {
        Self {
            family,
            thread_specific: Arc::new(RwLock::new(HashMap::with_hasher(BuildThreadIdHasher))),
        }
    }

    /// Returns the `Rc<T>` for the current thread, creating it if necessary.
    #[must_use]
    fn current_thread_instance(&self) -> Rc<T> {
        let thread_id = thread::current().id();

        // First, an optimistic pass - let us assume it is already initialized for our thread.
        // This only needs the (cheap, shared) read lock.
        {
            let map = self.thread_specific.read().expect(ERR_POISONED_LOCK);

            if let Some(state) = map.get(&thread_id) {
                // SAFETY: We must guarantee that we are on the thread that owns
                // the thread-specific state. We are - thread ID lookup led us here.
                return unsafe { state.clone_instance() };
            }
        }

        // The state for the current thread is not yet initialized. Let us initialize!
        // Note that we create this instance outside any locks, both to reduce the
        // lock durations but also because cloning a linked object may execute arbitrary code,
        // including potentially code that tries to grab the same lock.
        let instance: Rc<T> = Rc::new(self.family.clone().into());

        // Let us add the new instance to the map.
        let mut map = self.thread_specific.write().expect(ERR_POISONED_LOCK);

        // In some wild corner cases, it is perhaps possible that the arbitrary code in the
        // linked object clone logic may already have filled the map with our value? It is
        // a bit of a stretch of imagination but let us accept the possibility to be thorough.
        match map.entry(thread_id) {
            hash_map::Entry::Occupied(occupied_entry) => {
                // There already is something in the entry. That is fine, we just ignore the
                // new instance we created and pretend we are on the optimistic path.
                let state = occupied_entry.get();

                // SAFETY: We must guarantee that we are on the thread that owns
                // the thread-specific state. We are - thread ID lookup led us here.
                unsafe { state.clone_instance() }
            }
            hash_map::Entry::Vacant(vacant_entry) => {
                // We are the first thread to create an instance. Let us insert it.
                // SAFETY: We must guarantee that any further access (taking the Rc or dropping)
                // takes place on the same thread as was used to call this function. We ensure this
                // by the thread ID lookup in the map key - we can only ever directly access map
                // entries owned by the current thread (though we may resize the map from any
                // thread, as it simply moves data in memory).
                let state = unsafe { ThreadSpecificState::new(Rc::clone(&instance)) };
                vacant_entry.insert(state);

                instance
            }
        }
    }

    /// Removes the current thread's entry from the map, dropping its `Rc<T>`.
    ///
    /// Called from the `Ref` destructor on the owning thread, which satisfies the
    /// "drop on owning thread" requirement of `ThreadSpecificState`.
    fn clear_current_thread_instance(&self) {
        // We need to clear the thread-specific state for this thread.
        let thread_id = thread::current().id();

        let mut map = self.thread_specific.write().expect(ERR_POISONED_LOCK);
        map.remove(&thread_id);
    }
}
368
369impl<T> Clone for FamilyStateReference<T>
370where
371 T: linked::Object,
372{
373 fn clone(&self) -> Self {
374 Self {
375 family: self.family.clone(),
376 thread_specific: Arc::clone(&self.thread_specific),
377 }
378 }
379}
380
impl<T> Drop for FamilyStateReference<T>
where
    T: linked::Object,
{
    #[cfg_attr(test, mutants::skip)] // This is just a sanity check, no functional behavior.
    fn drop(&mut self) {
        // If we are the last reference to the family state, this will drop the thread-specific map.
        // We need to ensure that the thread-specific state is empty before we drop the map.
        // This is a sanity check - if this fails, we have a defect somewhere in our code.

        if Arc::strong_count(&self.thread_specific) > 1 {
            // We are not the last reference to the family state,
            // so no state dropping will occur - having state in the map is fine.
            return;
        }

        if thread::panicking() {
            // If we are already panicking, there is no point in asserting anything,
            // as another panic may disrupt the handling of the original panic.
            return;
        }

        // Read lock suffices - we are the last reference, so nobody can mutate the
        // map concurrently; we only inspect it before it is dropped with `self`.
        let map = self.thread_specific.read().expect(ERR_POISONED_LOCK);
        assert!(
            map.is_empty(),
            "thread-specific state map was not empty on drop - internal logic error"
        );
    }
}
410
/// Holds the thread-specific state for a specific family of per-thread linked objects.
///
/// # Safety
///
/// This contains an `Rc`, which is `!Send` and only meant to be accessed from the thread it was
/// created on. Yet the instance of this type itself is visible from multiple threads and
/// potentially even touched (moved) from another thread when resizing the `HashMap` of all
/// instances! How can this be?!
///
/// We take advantage of the fact that an `Rc` is merely a reference to a control block.
/// As long as we never touch the control block from the wrong thread, nobody will ever
/// know we touched the `Rc` from another thread. This allows us to move the Rc around
/// in memory as long as the move itself is synchronized.
///
/// Obviously, this relies on `Rc` implementation details, so we are somewhat at risk of
/// breakage if a future Rust std implementation changes the way `Rc` works but this seems
/// unlikely as this is fairly fundamental to the nature of how smart pointers are created.
///
/// NB! We must not drop the Rc (and by extension this type) from a foreign thread!
#[derive(Debug)]
struct ThreadSpecificState<T>
where
    T: linked::Object,
{
    // Only `clone_instance()` may touch this, and only on the owning thread;
    // see the type-level safety comments above.
    instance: Rc<T>,
}
437
impl<T> ThreadSpecificState<T>
where
    T: linked::Object,
{
    /// Creates a new `ThreadSpecificState` with the given `Rc<T>`.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that any further access (including dropping) takes place on the
    /// same thread as was used to call this function.
    ///
    /// See type-level safety comments for details.
    #[must_use]
    unsafe fn new(instance: Rc<T>) -> Self {
        Self { instance }
    }

    /// Returns the `Rc<T>` for this thread.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that the current thread is the thread for which this
    /// `ThreadSpecificState` was created. This is not enforced by the type system.
    ///
    /// (Cloning touches the `Rc` control block, which is only sound from the
    /// owning thread - see type-level safety comments.)
    ///
    /// See type-level safety comments for details.
    #[must_use]
    unsafe fn clone_instance(&self) -> Rc<T> {
        Rc::clone(&self.instance)
    }
}
468
// SAFETY: See comments on type. Shared references are harmless because the only
// way to touch the inner `Rc` is via `unsafe` methods whose contract restricts
// the caller to the owning thread.
unsafe impl<T> Sync for ThreadSpecificState<T> where T: linked::Object {}
// SAFETY: See comments on type. Moving the struct between threads only moves the
// `Rc` pointer value in memory; the control block is never accessed cross-thread,
// and the caller contract forbids dropping from a foreign thread.
unsafe impl<T> Send for ThreadSpecificState<T> where T: linked::Object {}
473
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use std::cell::Cell;
    use std::sync::{Arc, Mutex};
    use std::thread;

    use super::*;

    // Test fixture: `shared_value` is linked across the whole family while
    // `local_value` is independent per linked instance (and thus per thread).
    #[linked::object]
    struct TokenCache {
        shared_value: Arc<Mutex<usize>>,
        local_value: Cell<usize>,
    }

    impl TokenCache {
        fn new() -> Self {
            let shared_value = Arc::new(Mutex::new(0));

            linked::new!(Self {
                shared_value: Arc::clone(&shared_value),
                local_value: Cell::new(0),
            })
        }

        // Increments both the per-instance counter and the family-wide counter.
        fn increment(&self) {
            self.local_value.set(self.local_value.get().wrapping_add(1));

            let mut shared_value = self.shared_value.lock().unwrap();
            *shared_value = shared_value.wrapping_add(1);
        }

        fn local_value(&self) -> usize {
            self.local_value.get()
        }

        fn shared_value(&self) -> usize {
            *self.shared_value.lock().unwrap()
        }
    }

    #[test]
    fn per_thread_smoke_test() {
        let linked_cache = InstancePerThread::new(TokenCache::new());

        let cache1 = linked_cache.acquire();
        cache1.increment();

        assert_eq!(cache1.local_value(), 1);
        assert_eq!(cache1.shared_value(), 1);

        // This must refer to the same instance.
        let cache2 = linked_cache.acquire();

        assert_eq!(cache2.local_value(), 1);
        assert_eq!(cache2.shared_value(), 1);

        cache2.increment();

        assert_eq!(cache1.local_value(), 2);
        assert_eq!(cache1.shared_value(), 2);

        thread::spawn(move || {
            // You can move InstancePerThread across threads.
            let cache3 = linked_cache.acquire();

            // This is a different thread's instance, so the local value is fresh.
            assert_eq!(cache3.local_value(), 0);
            assert_eq!(cache3.shared_value(), 2);

            cache3.increment();

            assert_eq!(cache3.local_value(), 1);
            assert_eq!(cache3.shared_value(), 3);

            // You can clone this and every clone works the same.
            let thread_local_clone = linked_cache.clone();

            let cache4 = thread_local_clone.acquire();

            assert_eq!(cache4.local_value(), 1);
            assert_eq!(cache4.shared_value(), 3);

            // Every InstancePerThread instance from the same family is equivalent.
            let cache5 = linked_cache.acquire();

            assert_eq!(cache5.local_value(), 1);
            assert_eq!(cache5.shared_value(), 3);

            thread::spawn(move || {
                let cache6 = thread_local_clone.acquire();

                // This is a different thread's instance, so the local value is fresh.
                assert_eq!(cache6.local_value(), 0);
                assert_eq!(cache6.shared_value(), 3);

                cache6.increment();

                assert_eq!(cache6.local_value(), 1);
                assert_eq!(cache6.shared_value(), 4);
            })
            .join()
            .unwrap();
        })
        .join()
        .unwrap();

        // The original thread's instance was untouched by the other threads.
        assert_eq!(cache1.local_value(), 2);
        assert_eq!(cache1.shared_value(), 4);
    }

    #[test]
    fn thread_state_dropped_on_last_thread_local_drop() {
        let linked_cache = InstancePerThread::new(TokenCache::new());

        let cache = linked_cache.acquire();
        cache.increment();

        assert_eq!(cache.local_value(), 1);

        // This will drop the local state.
        drop(cache);

        // We get a fresh instance now, initialized from scratch for this thread.
        let cache = linked_cache.acquire();
        assert_eq!(cache.local_value(), 0);
    }

    #[test]
    fn thread_state_dropped_on_thread_exit() {
        // At the start, no thread-specific state has been created. The link embedded into the
        // InstancePerThread holds one reference to the inner shared value of the TokenCache.
        let linked_cache = InstancePerThread::new(TokenCache::new());

        let cache = linked_cache.acquire();

        // We now have two references to the inner shared value - the link + this fn.
        assert_eq!(Arc::strong_count(&cache.shared_value), 2);

        thread::spawn(move || {
            let cache = linked_cache.acquire();

            assert_eq!(Arc::strong_count(&cache.shared_value), 3);
        })
        .join()
        .unwrap();

        // Should be back to 2 here - the thread-local state was dropped when the thread exited.
        assert_eq!(Arc::strong_count(&cache.shared_value), 2);
    }
}
624}