linked/per_thread.rs
1use std::{
2 collections::{HashMap, hash_map},
3 ops::Deref,
4 rc::Rc,
5 sync::{Arc, RwLock},
6 thread::{self, ThreadId},
7};
8
9use simple_mermaid::mermaid;
10
11use crate::{BuildThreadIdHasher, ERR_POISONED_LOCK};
12
13/// A wrapper that manages instances of linked objects of type `T`, ensuring that only one
14/// instance of `T` is created per thread.
15///
16/// This is a conceptual equivalent of the [`linked::instance_per_thread!` macro][1], with the main
17/// difference being that this type operates entirely at runtime using dynamic storage and does
18/// not require a static variable to be defined.
19///
20/// # Usage
21///
22/// Create an instance of `PerThread` and provide it an initial instance of a linked object `T`.
23/// This initial instance will be used to create additional instances on demand. Any instance
24/// of `T` retrieved through the same `PerThread` or a clone will be linked to the same family
25/// of `T` instances.
26///
27#[ doc=mermaid!( "../doc/per_thread.mermaid") ]
28///
29/// To access the current thread's instance of `T`, you must first obtain a
30/// [`ThreadLocal<T>`][ThreadLocal] which works in a manner similar to `Rc<T>`, allowing you to
31/// reference the value within. You can obtain a [`ThreadLocal<T>`][ThreadLocal] by calling
32/// [`PerThread::local()`][Self::local].
33///
34/// Once you have a [`ThreadLocal<T>`][ThreadLocal], you can access the `T` within by simply
35/// dereferencing via the `Deref<Target = T>` trait.
36///
37/// # Long-lived thread-specific instances
38///
39/// Note that the `ThreadLocal` type is `!Send`, which means you cannot store it in places that
40/// need to be thread-mobile. For example, in web framework request handlers the compiler might
41/// not permit you to let a `ThreadLocal` live across an `await`, depending on the web framework,
42/// the async task runtime used and its specific configuration.
43///
44/// # Resource management
45///
46/// A thread-specific instance of `T` is dropped when the last `ThreadLocal` on that thread is
47/// dropped. If a new `ThreadLocal` is later obtained, it is initialized with a new instance
48/// of the linked object.
49///
50/// It is important to emphasize that this means if you only create temporary `ThreadLocal`
51/// instances then you will get a new instance of `T` every time. The performance impact of
52/// this depends on how `T` works internally but you are recommended to keep `ThreadLocal`
53/// instances around for reuse when possible.
54///
55/// # Advanced scenarios
56///
57/// Use of `PerThread` does not close the door on other ways to use linked objects.
58/// For example, you always have the possibility of manually taking the `T` and creating
59/// additional clones of it to break out the one-per-thread limitation. The `PerThread` type
60/// only controls what happens through the `PerThread` type.
61///
62/// [1]: crate::instance_per_thread
63#[derive(Debug)]
64pub struct PerThread<T>
65where
66 T: linked::Object,
67{
68 family: FamilyStateReference<T>,
69}
70
71impl<T> PerThread<T>
72where
73 T: linked::Object,
74{
75 /// Creates a new `PerThread` with an existing instance of `T`. Any further access to the `T`
76 /// via the `PerThread` (or its clones) will return instances of `T` from the same family.
77 #[expect(
78 clippy::needless_pass_by_value,
79 reason = "intentional needless consume to encourage all access to go via ThreadLocal<T>"
80 )]
81 #[must_use]
82 pub fn new(inner: T) -> Self {
83 let family = FamilyStateReference::new(inner.handle());
84
85 Self { family }
86 }
87
88 /// Returns a `ThreadLocal<T>` that can be used to efficiently access the current
89 /// thread's `T` instance.
90 ///
91 /// # Example
92 ///
93 /// ```
94 /// # use std::cell::Cell;
95 /// #
96 /// # #[linked::object]
97 /// # struct Thing {
98 /// # local_value: Cell<usize>,
99 /// # }
100 /// #
101 /// # impl Thing {
102 /// # pub fn new() -> Self {
103 /// # linked::new!(Self { local_value: Cell::new(0) })
104 /// # }
105 /// #
106 /// # pub fn increment(&self) {
107 /// # self.local_value.set(self.local_value.get() + 1);
108 /// # }
109 /// #
110 /// # pub fn local_value(&self) -> usize {
111 /// # self.local_value.get()
112 /// # }
113 /// # }
114 /// #
115 /// let per_thread_thing = linked::PerThread::new(Thing::new());
116 ///
117 /// let local_thing = per_thread_thing.local();
118 /// local_thing.increment();
119 /// assert_eq!(local_thing.local_value(), 1);
120 /// ```
121 ///
122 /// # Efficiency
123 ///
124 /// Reuse the returned instance as much as possible. Every call to this function has some
125 /// overhead, especially if there are no other `ThreadLocal<T>` instances from the same family
126 /// active on the current thread.
127 ///
128 /// # Instance lifecycle
129 ///
130 /// A thread-specific instance of `T` is dropped when the last `ThreadLocal` on that thread is
131 /// dropped. If a new `ThreadLocal` is later obtained, it is initialized with a new instance
132 /// of the linked object.
133 ///
134 /// ```
135 /// # use std::cell::Cell;
136 /// #
137 /// # #[linked::object]
138 /// # struct Thing {
139 /// # local_value: Cell<usize>,
140 /// # }
141 /// #
142 /// # impl Thing {
143 /// # pub fn new() -> Self {
144 /// # linked::new!(Self { local_value: Cell::new(0) })
145 /// # }
146 /// #
147 /// # pub fn increment(&self) {
148 /// # self.local_value.set(self.local_value.get() + 1);
149 /// # }
150 /// #
151 /// # pub fn local_value(&self) -> usize {
152 /// # self.local_value.get()
153 /// # }
154 /// # }
155 /// #
156 /// let per_thread_thing = linked::PerThread::new(Thing::new());
157 ///
158 /// let local_thing = per_thread_thing.local();
159 /// local_thing.increment();
160 /// assert_eq!(local_thing.local_value(), 1);
161 ///
162 /// drop(local_thing);
163 ///
164 /// // Dropping the only thread-local instance above will have reset the thread-local state.
165 /// let local_thing = per_thread_thing.local();
166 /// assert_eq!(local_thing.local_value(), 0);
167 /// ```
168 ///
169 /// To minimize the effort spent on re-creating the thread-local state, ensure that you reuse
170 /// the `ThreadLocal<T>` instances as much as possible.
171 ///
172 /// # Thread safety
173 ///
174 /// The returned value is single-threaded and cannot be moved or used across threads. For
175 /// transfer across threads, you need to preserve and share/send a `PerThread<T>` instance.
176 #[must_use]
177 pub fn local(&self) -> ThreadLocal<T> {
178 let inner = self.family.current_thread_instance();
179
180 ThreadLocal {
181 inner,
182 family: self.family.clone(),
183 }
184 }
185}
186
187impl<T> Clone for PerThread<T>
188where
189 T: linked::Object,
190{
191 #[inline]
192 fn clone(&self) -> Self {
193 Self {
194 family: self.family.clone(),
195 }
196 }
197}
198
199/// A thread-local instance of a linked object of type `T`. This acts in a manner similar to
200/// `Rc<T>` for a type `T` that implements the [linked object pattern][crate].
201///
202/// For details, see [`PerThread<T>`][PerThread] which is the type used to create instances
203/// of `ThreadLocal<T>`.
204#[derive(Debug)]
205pub struct ThreadLocal<T>
206where
207 T: linked::Object,
208{
209 // We really are just a wrapper around an Rc<T>. The only other duty we have
210 // is to clean up the thread-local instance when the last ThreadLocal is dropped.
211 inner: Rc<T>,
212 family: FamilyStateReference<T>,
213}
214
215impl<T> Deref for ThreadLocal<T>
216where
217 T: linked::Object,
218{
219 type Target = T;
220
221 #[inline]
222 fn deref(&self) -> &Self::Target {
223 &self.inner
224 }
225}
226
227impl<T> Clone for ThreadLocal<T>
228where
229 T: linked::Object,
230{
231 #[inline]
232 fn clone(&self) -> Self {
233 Self {
234 inner: Rc::clone(&self.inner),
235 family: self.family.clone(),
236 }
237 }
238}
239
240impl<T> Drop for ThreadLocal<T>
241where
242 T: linked::Object,
243{
244 fn drop(&mut self) {
245 // If we were the last ThreadLocal on this thread then we need to drop the thread-local
246 // state for this thread. Note that there are 2 references - ourselves and the family state.
247 if Rc::strong_count(&self.inner) != 2 {
248 // No - there is another ThreadLocal, so we do not need to clean up.
249 return;
250 }
251
252 self.family.clear_current_thread_instance();
253
254 // `self.inner` is now the last reference to the current thread's instance of T
255 // and this instance will be dropped once this function returns and drops the last `Rc<T>`.
256 }
257}
258
259/// One reference to the state of a specific family of per-thread linked objects.
260/// This can be used to retrieve and/or initialize the current thread's instance.
261#[derive(Debug)]
262struct FamilyStateReference<T>
263where
264 T: linked::Object,
265{
266 // If a thread needs a new instance, we create it via this handle.
267 handle: linked::Handle<T>,
268
269 // We store the state of each thread here. See safety comments on ThreadSpecificState!
270 // NB! While it is legal to manipulate the HashMap from any thread, including to move
271 // the values, calling actual functions on a value is only valid from the thread in the key.
272 //
273 // To ensure safety, we must also ensure that all values are removed from here before the map
274 // is dropped, because each value must be dropped on the thread that created it and dropping is
275 // logic executed on that thread-specific value!
276 //
277 // This is done in the `ThreadLocal` destructor. By the time this map is dropped, it must be
278 // empty, which we assert in our own drop().
279 //
280 // The write lock here is only held when initializing the thread-specific state for a thread
281 // for the first time, which should generally be rare, especially as user code will also be
282 // motivated to reduce those instances because it also means initializing the actual `T` inside.
283 // Most access will therefore only need to take a read lock.
284 thread_specific: Arc<RwLock<HashMap<ThreadId, ThreadSpecificState<T>, BuildThreadIdHasher>>>,
285}
286
287impl<T> FamilyStateReference<T>
288where
289 T: linked::Object,
290{
291 #[must_use]
292 fn new(handle: linked::Handle<T>) -> Self {
293 Self {
294 handle,
295 thread_specific: Arc::new(RwLock::new(HashMap::with_hasher(BuildThreadIdHasher))),
296 }
297 }
298
299 /// Returns the `Rc<T>` for the current thread, creating it if necessary.
300 #[must_use]
301 fn current_thread_instance(&self) -> Rc<T> {
302 let thread_id = thread::current().id();
303
304 // First, an optimistic pass - let's assume it is already initialized for our thread.
305 {
306 let map = self.thread_specific.read().expect(ERR_POISONED_LOCK);
307
308 if let Some(state) = map.get(&thread_id) {
309 // SAFETY: We must guarantee that we are on the thread that owns
310 // the thread-specific state. We are - thread ID lookup led us here.
311 return unsafe { state.clone_instance() };
312 }
313 }
314
315 // The state for the current thread is not yet initialized. Let's initialize!
316 // Note that we create this instance outside any locks, both to reduce the
317 // lock durations but also because cloning a linked object may execute arbitrary code,
318 // including potentially code that tries to grab the same lock.
319 let instance: Rc<T> = Rc::new(self.handle.clone().into());
320
321 // Let's add the new instance to the map.
322 let mut map = self.thread_specific.write().expect(ERR_POISONED_LOCK);
323
324 // In some wild corner cases, it is perhaps possible that the arbitrary code in the
325 // linked object clone logic may already have filled the map with our value? It is
326 // a bit of a stretch of imagination but let's accept the possibility to be thorough.
327 match map.entry(thread_id) {
328 hash_map::Entry::Occupied(occupied_entry) => {
329 // There already is something in the entry. That's fine, we just ignore the
330 // new instance we created and pretend we are on the optimistic path.
331 let state = occupied_entry.get();
332
333 // SAFETY: We must guarantee that we are on the thread that owns
334 // the thread-specific state. We are - thread ID lookup led us here.
335 unsafe { state.clone_instance() }
336 }
337 hash_map::Entry::Vacant(vacant_entry) => {
338 // We are the first thread to create an instance. Let's insert it.
339 // SAFETY: We must guarantee that any further access (taking the Rc or dropping)
340 // takes place on the same thread as was used to call this function. We ensure this
341 // by the thread ID lookup in the map key - we can only ever directly access map
342 // entries owned by the current thread (though we may resize the map from any
343 // thread, as it simply moves data in memory).
344 let state = unsafe { ThreadSpecificState::new(Rc::clone(&instance)) };
345 vacant_entry.insert(state);
346
347 instance
348 }
349 }
350 }
351
352 fn clear_current_thread_instance(&self) {
353 // We need to clear the thread-specific state for this thread.
354 let thread_id = thread::current().id();
355
356 let mut map = self.thread_specific.write().expect(ERR_POISONED_LOCK);
357 map.remove(&thread_id);
358 }
359}
360
361impl<T> Clone for FamilyStateReference<T>
362where
363 T: linked::Object,
364{
365 fn clone(&self) -> Self {
366 Self {
367 handle: self.handle.clone(),
368 thread_specific: Arc::clone(&self.thread_specific),
369 }
370 }
371}
372
373impl<T> Drop for FamilyStateReference<T>
374where
375 T: linked::Object,
376{
377 #[cfg_attr(test, mutants::skip)] // This is just a sanity check, no functional behavior.
378 fn drop(&mut self) {
379 // If we are the last reference to the family state, this will drop the thread-specific map.
380 // We need to ensure that the thread-specific state is empty before we drop the map.
381 // This is a sanity check - if this fails, we have a defect somewhere in our code.
382
383 if Arc::strong_count(&self.thread_specific) > 1 {
384 // We are not the last reference to the family state,
385 // so no state dropping will occur - having state in the map is fine.
386 return;
387 }
388
389 let map = self.thread_specific.read().expect(ERR_POISONED_LOCK);
390 assert!(
391 map.is_empty(),
392 "thread-specific state map was not empty on drop - internal logic error"
393 );
394 }
395}
396
397/// Holds the thread-specific state for a specific family of per-thread linked objects.
398///
399/// # Safety
400///
401/// This contains an `Rc`, which is `!Send` and only meant to be accessed from the thread it was
402/// created on. Yet the instance of this type itself is visible from multiple threads and
403/// potentially even touched (moved) from another thread when resizing the `HashMap` of all
404/// instances! How can this be?!
405///
406/// We take advantage of the fact that an `Rc` is merely a reference to a control block.
407/// As long as we never touch the control block from the wrong thread, nobody will ever
408/// know we touched the `Rc` from another thread. This allows us to move the Rc around
409/// in memory as long as the move itself is synchronized.
410///
411/// Obviously, this relies on `Rc` implementation details, so we are somewhat at risk of
412/// breakage if a future Rust std implementation changes the way `Rc` works but this seems
413/// unlikely as this is fairly fundamental to the nature of how smart pointers are created.
414///
415/// NB! We must not drop the Rc (and by extension this type) from a foreign thread!
416#[derive(Debug)]
417struct ThreadSpecificState<T>
418where
419 T: linked::Object,
420{
421 instance: Rc<T>,
422}
423
424impl<T> ThreadSpecificState<T>
425where
426 T: linked::Object,
427{
428 /// Creates a new `ThreadSpecificState` with the given `Rc<T>`.
429 ///
430 /// # Safety
431 ///
432 /// The caller must guarantee that any further access (including dropping) takes place on the
433 /// same thread as was used to call this function.
434 ///
435 /// See type-level safety comments for details.
436 #[must_use]
437 unsafe fn new(instance: Rc<T>) -> Self {
438 Self { instance }
439 }
440
441 /// Returns the `Rc<T>` for this thread.
442 ///
443 /// # Safety
444 ///
445 /// The caller must guarantee that the current thread is the thread for which this
446 /// `ThreadSpecificState` was created. This is not enforced by the type system.
447 ///
448 /// See type-level safety comments for details.
449 #[must_use]
450 unsafe fn clone_instance(&self) -> Rc<T> {
451 Rc::clone(&self.instance)
452 }
453}
454
455// SAFETY: See comments on type.
456unsafe impl<T> Sync for ThreadSpecificState<T> where T: linked::Object {}
457// SAFETY: See comments on type.
458unsafe impl<T> Send for ThreadSpecificState<T> where T: linked::Object {}
459
#[cfg(test)]
mod tests {
    use std::{
        cell::Cell,
        sync::{Arc, Mutex},
        thread,
    };

    use super::*;

    // Test fixture: a linked object with one counter private to each instance (`local_value`)
    // and one counter shared by every instance in the family (`shared_value`).
    #[linked::object]
    struct TokenCache {
        shared_value: Arc<Mutex<usize>>,
        local_value: Cell<usize>,
    }

    impl TokenCache {
        fn new() -> Self {
            #[expect(
                clippy::mutex_atomic,
                reason = "inner type is placeholder, for realistic usage"
            )]
            let shared_value = Arc::new(Mutex::new(0));

            linked::new!(Self {
                shared_value: Arc::clone(&shared_value),
                local_value: Cell::new(0),
            })
        }

        fn increment(&self) {
            let bumped_local = self.local_value.get().wrapping_add(1);
            self.local_value.set(bumped_local);

            let mut guard = self.shared_value.lock().unwrap();
            *guard = guard.wrapping_add(1);
        }

        fn local_value(&self) -> usize {
            self.local_value.get()
        }

        fn shared_value(&self) -> usize {
            let guard = self.shared_value.lock().unwrap();
            *guard
        }
    }

    #[test]
    fn per_thread_smoke_test() {
        let per_thread = PerThread::new(TokenCache::new());

        let first = per_thread.local();
        first.increment();

        assert_eq!(first.local_value(), 1);
        assert_eq!(first.shared_value(), 1);

        // A second `local()` call on the same thread hands out the very same instance.
        let second = per_thread.local();

        assert_eq!(second.local_value(), 1);
        assert_eq!(second.shared_value(), 1);

        second.increment();

        assert_eq!(first.local_value(), 2);
        assert_eq!(first.shared_value(), 2);

        thread::spawn(move || {
            // A `PerThread` can be moved to another thread...
            let on_worker = per_thread.local();

            // ...where it yields that thread's own, freshly created instance.
            assert_eq!(on_worker.local_value(), 0);
            assert_eq!(on_worker.shared_value(), 2);

            on_worker.increment();

            assert_eq!(on_worker.local_value(), 1);
            assert_eq!(on_worker.shared_value(), 3);

            // Clones of a `PerThread` behave exactly like the original.
            let cloned_per_thread = per_thread.clone();

            let via_clone = cloned_per_thread.local();

            assert_eq!(via_clone.local_value(), 1);
            assert_eq!(via_clone.shared_value(), 3);

            // Any `PerThread` of the family resolves to the same per-thread instance.
            let via_original = per_thread.local();

            assert_eq!(via_original.local_value(), 1);
            assert_eq!(via_original.shared_value(), 3);

            thread::spawn(move || {
                let on_nested_worker = cloned_per_thread.local();

                // Yet another thread, so yet another fresh instance.
                assert_eq!(on_nested_worker.local_value(), 0);
                assert_eq!(on_nested_worker.shared_value(), 3);

                on_nested_worker.increment();

                assert_eq!(on_nested_worker.local_value(), 1);
                assert_eq!(on_nested_worker.shared_value(), 4);
            })
            .join()
            .unwrap();
        })
        .join()
        .unwrap();

        assert_eq!(first.local_value(), 2);
        assert_eq!(first.shared_value(), 4);
    }

    #[test]
    fn thread_state_dropped_on_last_thread_local_drop() {
        let per_thread = PerThread::new(TokenCache::new());

        let before = per_thread.local();
        before.increment();

        assert_eq!(before.local_value(), 1);

        // Dropping the only `ThreadLocal` on this thread discards this thread's instance...
        drop(before);

        // ...so the next request has to build a brand new one from scratch.
        let after = per_thread.local();
        assert_eq!(after.local_value(), 0);
    }

    #[test]
    fn thread_state_dropped_on_thread_exit() {
        // No thread-specific state exists yet. At this point the only reference to the shared
        // value of the TokenCache is the one held by the link embedded into the PerThread.
        let per_thread = PerThread::new(TokenCache::new());

        let main_local = per_thread.local();

        // Two references now: the link plus this thread's instance.
        assert_eq!(Arc::strong_count(&main_local.shared_value), 2);

        thread::spawn(move || {
            let worker_local = per_thread.local();

            // The worker thread's own instance adds a third.
            assert_eq!(Arc::strong_count(&worker_local.shared_value), 3);
        })
        .join()
        .unwrap();

        // Back to two: the worker's thread-local state was dropped when the thread exited.
        assert_eq!(Arc::strong_count(&main_local.shared_value), 2);
    }
}