use std::collections::{HashMap, hash_map};
use std::ops::Deref;
use std::rc::Rc;
use std::sync::{Arc, RwLock};
use std::thread::{self, ThreadId};
use simple_mermaid::mermaid;
use crate::{BuildThreadIdHasher, ERR_POISONED_LOCK};
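/// Provides each thread with its own instance of a linked object from one
/// object family. The instance is created on the thread's first call to
/// [`acquire()`](Self::acquire) and dropped when the thread releases its
/// last [`Ref`] to it.
///
/// # Example
///
/// A minimal usage sketch; `Counter` is an illustrative type modeled on this
/// module's tests, not part of the crate:
///
/// ```
/// use std::sync::Arc;
/// use std::sync::atomic::{AtomicUsize, Ordering};
///
/// use linked::InstancePerThread;
///
/// #[linked::object]
/// struct Counter {
///     // Shared by every instance in the family.
///     count: Arc<AtomicUsize>,
/// }
///
/// impl Counter {
///     fn new() -> Self {
///         let count = Arc::new(AtomicUsize::new(0));
///         linked::new!(Self {
///             count: Arc::clone(&count),
///         })
///     }
///
///     fn increment(&self) {
///         self.count.fetch_add(1, Ordering::Relaxed);
///     }
/// }
///
/// let per_thread = InstancePerThread::new(Counter::new());
///
/// // Each thread that calls `acquire()` gets its own `Counter`.
/// let counter = per_thread.acquire();
/// counter.increment();
/// ```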
#[doc = mermaid!("../doc/instance_per_thread.mermaid")]
#[derive(Debug)]
pub struct InstancePerThread<T>
where
T: linked::Object,
{
family: FamilyStateReference<T>,
}
impl<T> InstancePerThread<T>
where
T: linked::Object,
{
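    /// Creates a new `InstancePerThread<T>` from an existing instance of `T`.
    ///
    /// The instance is consumed; only its family handle is retained, and each
    /// thread's own instance is created from that family on first acquire.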
#[expect(
clippy::needless_pass_by_value,
    reason = "intentionally consumes the value so that all access goes via InstancePerThread<T>"
)]
#[must_use]
pub fn new(inner: T) -> Self {
let family = FamilyStateReference::new(inner.family());
Self { family }
}
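    /// Returns a [`Ref`] to the current thread's instance of `T`, creating
    /// the instance if this thread has never acquired one before (or has
    /// since released all of its `Ref`s).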
#[must_use]
pub fn acquire(&self) -> Ref<T> {
let inner = self.family.current_thread_instance();
Ref {
inner,
family: self.family.clone(),
}
}
}
impl<T> Clone for InstancePerThread<T>
where
T: linked::Object,
{
#[inline]
fn clone(&self) -> Self {
Self {
family: self.family.clone(),
}
}
}
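/// A smart pointer that dereferences to the current thread's instance of `T`.
///
/// The thread's instance is kept alive as long as at least one `Ref` to it
/// exists; dropping the last `Ref` on a thread drops that thread's instance.
/// `Ref` wraps an `Rc` and therefore must stay on the thread it was acquired on.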
#[derive(Debug)]
pub struct Ref<T>
where
T: linked::Object,
{
inner: Rc<T>,
family: FamilyStateReference<T>,
}
impl<T> Deref for Ref<T>
where
T: linked::Object,
{
type Target = T;
#[inline]
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<T> Clone for Ref<T>
where
T: linked::Object,
{
#[inline]
fn clone(&self) -> Self {
Self {
inner: Rc::clone(&self.inner),
family: self.family.clone(),
}
}
}
impl<T> Drop for Ref<T>
where
T: linked::Object,
{
fn drop(&mut self) {
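        // When the last `Ref` on this thread is dropped, exactly two strong
        // references remain: the one in `self` and the one in the family's
        // thread-specific map. A higher count means other `Ref`s are still
        // alive, so the map entry must be kept.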
if Rc::strong_count(&self.inner) != 2 {
return;
}
self.family.clear_current_thread_instance();
}
}
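/// Shared state of one object family: the family handle used to create new
/// instances, plus a map from thread ID to that thread's current instance.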
#[derive(Debug)]
struct FamilyStateReference<T>
where
T: linked::Object,
{
family: linked::Family<T>,
thread_specific: Arc<RwLock<HashMap<ThreadId, ThreadSpecificState<T>, BuildThreadIdHasher>>>,
}
impl<T> FamilyStateReference<T>
where
T: linked::Object,
{
#[must_use]
fn new(family: linked::Family<T>) -> Self {
Self {
family,
thread_specific: Arc::new(RwLock::new(HashMap::with_hasher(BuildThreadIdHasher))),
}
}
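    /// Returns the current thread's instance, creating it and registering it
    /// in the thread-specific map on first access from this thread.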
#[must_use]
fn current_thread_instance(&self) -> Rc<T> {
let thread_id = thread::current().id();
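        // Fast path: an instance already exists for this thread, so we only
        // need the read lock to clone the `Rc`.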
{
let map = self.thread_specific.read().expect(ERR_POISONED_LOCK);
if let Some(state) = map.get(&thread_id) {
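                // SAFETY: The entry is keyed by the current thread's ID, so
                // the `Rc` is cloned on the thread that owns it.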
return unsafe { state.clone_instance() };
}
}
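        // Slow path: create the instance before taking the write lock because
        // the constructor of `T` may itself call `acquire()` and reenter this
        // function; holding the write lock across that call would deadlock.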
let instance: Rc<T> = Rc::new(self.family.clone().into());
let mut map = self.thread_specific.write().expect(ERR_POISONED_LOCK);
match map.entry(thread_id) {
hash_map::Entry::Occupied(occupied_entry) => {
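                // A reentrant call during construction already registered an
                // instance for this thread; return that one and let the
                // instance we just created drop unused.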
let state = occupied_entry.get();
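                // SAFETY: The entry is keyed by the current thread's ID, so
                // the `Rc` is cloned on the thread that owns it.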
unsafe { state.clone_instance() }
}
hash_map::Entry::Vacant(vacant_entry) => {
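                // SAFETY: The state is keyed by the current thread's ID, so
                // the `Rc` inside will only ever be accessed from this thread.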
let state = unsafe { ThreadSpecificState::new(Rc::clone(&instance)) };
vacant_entry.insert(state);
instance
}
}
}
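    /// Removes the current thread's instance from the map, dropping it.
    /// Called when the thread drops its last `Ref`.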
fn clear_current_thread_instance(&self) {
let thread_id = thread::current().id();
let mut map = self.thread_specific.write().expect(ERR_POISONED_LOCK);
map.remove(&thread_id);
}
}
impl<T> Clone for FamilyStateReference<T>
where
T: linked::Object,
{
fn clone(&self) -> Self {
Self {
family: self.family.clone(),
thread_specific: Arc::clone(&self.thread_specific),
}
}
}
impl<T> Drop for FamilyStateReference<T>
where
T: linked::Object,
{
    #[cfg_attr(test, mutants::skip)]
    fn drop(&mut self) {
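        // Only the last remaining `FamilyStateReference` performs the
        // consistency check; earlier clones may drop while the map is still
        // in use by other references.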
if Arc::strong_count(&self.thread_specific) > 1 {
return;
}
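        // Skip the check while panicking: the map may legitimately still
        // contain entries, and asserting here would turn the unwind into an
        // abort via a double panic.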
if thread::panicking() {
return;
}
let map = self.thread_specific.read().expect(ERR_POISONED_LOCK);
assert!(
map.is_empty(),
"thread-specific state map was not empty on drop - internal logic error"
);
}
}
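/// One thread's instance, as stored in the family's thread-specific map.
/// The `Rc` inside must only ever be touched from the thread whose ID keys
/// the entry; this invariant is what makes the `Send`/`Sync` impls below sound.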
#[derive(Debug)]
struct ThreadSpecificState<T>
where
T: linked::Object,
{
instance: Rc<T>,
}
impl<T> ThreadSpecificState<T>
where
T: linked::Object,
{
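    /// # Safety
    ///
    /// The caller must key the resulting state by the current thread's ID and
    /// only access it from that thread, because `Rc` is not thread-safe.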
#[must_use]
unsafe fn new(instance: Rc<T>) -> Self {
Self { instance }
}
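    /// # Safety
    ///
    /// The caller must only call this from the thread that created the state,
    /// because the non-atomic `Rc` reference count must not be updated from
    /// another thread.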
#[must_use]
unsafe fn clone_instance(&self) -> Rc<T> {
Rc::clone(&self.instance)
}
}
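// SAFETY: The thread-specific map is shared across threads, but each entry's
// `Rc` is only created, cloned, and dropped by the thread whose ID keys it
// (per the unsafe contracts above), so the non-atomic reference count is
// never accessed concurrently.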
unsafe impl<T> Sync for ThreadSpecificState<T> where T: linked::Object {}
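// SAFETY: Same invariant as above; entries are expected to be removed by
// their owning threads before the map itself is dropped (the drop assertion
// in `FamilyStateReference` checks this), so no `Rc` is ever manipulated
// off its owning thread.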
unsafe impl<T> Send for ThreadSpecificState<T> where T: linked::Object {}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
use std::cell::Cell;
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::sync::atomic::{self, AtomicBool};
use std::sync::{Arc, Mutex};
use std::thread;
use static_assertions::assert_impl_all;
use super::*;
assert_impl_all!(InstancePerThread<TokenCache>: UnwindSafe, RefUnwindSafe);
#[linked::object]
struct SimpleValue {
#[expect(
dead_code,
reason = "field exists to give the type a String \
component for static assertions"
)]
data: String,
}
assert_impl_all!(Ref<SimpleValue>: UnwindSafe, RefUnwindSafe);
#[linked::object]
struct TokenCache {
shared_value: Arc<Mutex<usize>>,
local_value: Cell<usize>,
}
impl TokenCache {
fn new() -> Self {
let shared_value = Arc::new(Mutex::new(0));
linked::new!(Self {
shared_value: Arc::clone(&shared_value),
local_value: Cell::new(0),
})
}
fn increment(&self) {
self.local_value.set(self.local_value.get().wrapping_add(1));
let mut shared_value = self.shared_value.lock().unwrap();
*shared_value = shared_value.wrapping_add(1);
}
fn local_value(&self) -> usize {
self.local_value.get()
}
fn shared_value(&self) -> usize {
*self.shared_value.lock().unwrap()
}
}
#[test]
fn per_thread_smoke_test() {
let linked_cache = InstancePerThread::new(TokenCache::new());
let cache1 = linked_cache.acquire();
cache1.increment();
assert_eq!(cache1.local_value(), 1);
assert_eq!(cache1.shared_value(), 1);
let cache2 = linked_cache.acquire();
assert_eq!(cache2.local_value(), 1);
assert_eq!(cache2.shared_value(), 1);
cache2.increment();
assert_eq!(cache1.local_value(), 2);
assert_eq!(cache1.shared_value(), 2);
thread::spawn(move || {
let cache3 = linked_cache.acquire();
assert_eq!(cache3.local_value(), 0);
assert_eq!(cache3.shared_value(), 2);
cache3.increment();
assert_eq!(cache3.local_value(), 1);
assert_eq!(cache3.shared_value(), 3);
let thread_local_clone = linked_cache.clone();
let cache4 = thread_local_clone.acquire();
assert_eq!(cache4.local_value(), 1);
assert_eq!(cache4.shared_value(), 3);
let cache5 = linked_cache.acquire();
assert_eq!(cache5.local_value(), 1);
assert_eq!(cache5.shared_value(), 3);
thread::spawn(move || {
let cache6 = thread_local_clone.acquire();
assert_eq!(cache6.local_value(), 0);
assert_eq!(cache6.shared_value(), 3);
cache6.increment();
assert_eq!(cache6.local_value(), 1);
assert_eq!(cache6.shared_value(), 4);
})
.join()
.unwrap();
})
.join()
.unwrap();
assert_eq!(cache1.local_value(), 2);
assert_eq!(cache1.shared_value(), 4);
}
#[test]
fn thread_state_dropped_on_last_thread_local_drop() {
let linked_cache = InstancePerThread::new(TokenCache::new());
let cache = linked_cache.acquire();
cache.increment();
assert_eq!(cache.local_value(), 1);
drop(cache);
let cache = linked_cache.acquire();
assert_eq!(cache.local_value(), 0);
}
#[test]
fn thread_state_dropped_on_thread_exit() {
let linked_cache = InstancePerThread::new(TokenCache::new());
let cache = linked_cache.acquire();
assert_eq!(Arc::strong_count(&cache.shared_value), 2);
thread::spawn(move || {
let cache = linked_cache.acquire();
assert_eq!(Arc::strong_count(&cache.shared_value), 3);
})
.join()
.unwrap();
assert_eq!(Arc::strong_count(&cache.shared_value), 2);
}
#[linked::object]
struct ReentrantType {}
impl ReentrantType {
fn new(
shared_ipt: Arc<Mutex<Option<InstancePerThread<Self>>>>,
reentry_flag: Arc<AtomicBool>,
) -> Self {
linked::__private::new(move |link| {
let maybe_ipt = shared_ipt.lock().unwrap().clone();
if let Some(ipt) = maybe_ipt
&& reentry_flag
.compare_exchange(
false,
true,
atomic::Ordering::Relaxed,
atomic::Ordering::Relaxed,
)
.is_ok()
{
let _inner_ref = ipt.acquire();
}
Self {
__private_linked_link: link,
}
})
}
}
#[test]
fn reentrant_acquire_hits_occupied_branch() {
let shared_ipt: Arc<Mutex<Option<InstancePerThread<ReentrantType>>>> =
Arc::new(Mutex::new(None));
let reentry_flag = Arc::new(AtomicBool::new(false));
let first = ReentrantType::new(Arc::clone(&shared_ipt), Arc::clone(&reentry_flag));
let ipt = InstancePerThread::new(first);
*shared_ipt.lock().unwrap() = Some(ipt.clone());
let _outer_ref = ipt.acquire();
assert!(reentry_flag.load(atomic::Ordering::Relaxed));
*shared_ipt.lock().unwrap() = None;
}
}