use std::collections::HashMap;
use std::error::Error;
use std::pin::Pin;
use std::sync::{Arc, Condvar, Mutex, MutexGuard};
use std::thread;
use crate::blocking_guards::{InaccessibleGuardBlocking, MutGuardBlocking, RefGuardBlocking};
use crate::cell::GdCellInner;
/// A wrapper around `GdCellInner` that makes cross-thread borrow conflicts wait instead of
/// being reported immediately.
///
/// Bookkeeping about which thread holds what lives in a `ThreadTracker` behind a mutex.
/// `borrow` parks on `immut_condition` while another thread holds the value mutably, and
/// `borrow_mut` parks on `mut_condition` while the value is bound elsewhere. Clones of both
/// condition variables (and of the tracker) are handed to the returned guards, which are
/// presumably responsible for notifying waiters when released — see the blocking guards.
pub struct GdCellBlocking<T> {
    inner: Pin<Box<GdCellInner<T>>>,
    thread_tracker: Arc<Mutex<ThreadTracker>>,
    immut_condition: Arc<Condvar>,
    mut_condition: Arc<Condvar>,
}

impl<T> GdCellBlocking<T> {
    /// Wraps `value` in a new blocking cell with fresh tracking state and condition variables.
    pub fn new(value: T) -> Self {
        let inner = GdCellInner::new(value);
        Self {
            inner,
            thread_tracker: Arc::new(Mutex::default()),
            immut_condition: Arc::default(),
            mut_condition: Arc::default(),
        }
    }

    /// Borrows the value immutably.
    ///
    /// When the value is mutably bound and the calling thread is not the one recorded as the
    /// mutable-reference holder, this call blocks until the mutable binding ends. Otherwise the
    /// request goes straight to the inner cell.
    ///
    /// # Errors
    /// Propagates any error returned by the inner cell's `borrow`.
    ///
    /// # Panics
    /// Panics if the tracker mutex is poisoned.
    pub fn borrow(&self) -> Result<RefGuardBlocking<'_, T>, Box<dyn Error>> {
        let mut tracker = self.thread_tracker.lock().unwrap();

        let held_mutably_elsewhere = self.inner.as_ref().is_currently_mutably_bound()
            && !tracker.current_thread_has_mut_ref();
        if held_mutably_elsewhere {
            tracker = self.block_immut(tracker);
        }

        // With no outstanding binding at all, this thread is recorded as the mutable-ref holder.
        let first_borrower = !self.is_currently_bound();
        let inner_guard = self.inner.as_ref().borrow()?;

        tracker.increment_current_thread_shared_count();
        if first_borrower {
            tracker.claim_mut_ref();
        }

        // The tracker lock is still held here and is released when `tracker` goes out of scope.
        Ok(RefGuardBlocking::new(
            inner_guard,
            Arc::clone(&self.mut_condition),
            Arc::clone(&self.thread_tracker),
        ))
    }

    /// Borrows the value mutably.
    ///
    /// Blocks while the value is bound, the calling thread holds no shared reference, and the
    /// calling thread is not the recorded mutable-reference holder. In every other case the
    /// request goes straight to the inner cell.
    ///
    /// # Errors
    /// Propagates any error returned by the inner cell's `borrow_mut`.
    ///
    /// # Panics
    /// Panics if the tracker mutex is poisoned.
    pub fn borrow_mut(&self) -> Result<MutGuardBlocking<'_, T>, Box<dyn Error>> {
        let mut tracker = self.thread_tracker.lock().unwrap();

        let must_wait = self.is_currently_bound()
            && tracker.current_thread_shared_count() == 0
            && !tracker.current_thread_has_mut_ref();
        if must_wait {
            tracker = self.block_mut(tracker);
        }

        let inner_guard = self.inner.as_ref().borrow_mut()?;
        tracker.claim_mut_ref();

        Ok(MutGuardBlocking::new(
            inner_guard,
            Arc::clone(&self.mut_condition),
            Arc::clone(&self.immut_condition),
            Arc::clone(&self.thread_tracker),
        ))
    }

    /// Temporarily gives up access through `current_ref`, delegating to the inner cell's
    /// `make_inaccessible` while holding the tracker lock.
    ///
    /// # Errors
    /// Propagates any error returned by the inner cell's `make_inaccessible`.
    pub fn make_inaccessible<'cell, 'val>(
        &'cell self,
        current_ref: &'val mut T,
    ) -> Result<InaccessibleGuardBlocking<'val, T>, Box<dyn Error>>
    where
        'cell: 'val,
    {
        // A named binding (not `_`) keeps the lock alive until the end of the function.
        let _lock = self.thread_tracker.lock().unwrap();
        let inner = self.inner.as_ref().make_inaccessible(current_ref)?;
        Ok(InaccessibleGuardBlocking::new(
            inner,
            Arc::clone(&self.thread_tracker),
        ))
    }

    /// Reports whether the inner cell is currently bound (delegates to `GdCellInner`).
    pub fn is_currently_bound(&self) -> bool {
        let inner = self.inner.as_ref();
        inner.is_currently_bound()
    }

    /// Waits on `mut_condition` until the inner cell is no longer bound, then returns the
    /// re-acquired tracker lock. Panics if the mutex is poisoned.
    fn block_mut<'a>(
        &self,
        tracker_guard: MutexGuard<'a, ThreadTracker>,
    ) -> MutexGuard<'a, ThreadTracker> {
        self.mut_condition
            .wait_while(tracker_guard, |_| self.inner.as_ref().is_currently_bound())
            .unwrap()
    }

    /// Waits on `immut_condition` until the inner cell is no longer mutably bound, then
    /// returns the re-acquired tracker lock. Panics if the mutex is poisoned.
    fn block_immut<'a>(
        &self,
        tracker_guard: MutexGuard<'a, ThreadTracker>,
    ) -> MutexGuard<'a, ThreadTracker> {
        self.immut_condition
            .wait_while(tracker_guard, |_| {
                self.inner.as_ref().is_currently_mutably_bound()
            })
            .unwrap()
    }
}
// SAFETY(review): this file gives no justification for this impl.
// `borrow` only blocks while the value is *mutably* bound by another thread, so several
// threads can obtain shared guards at the same time (assuming `GdCellInner` permits multiple
// simultaneous shared borrows, which the blocking logic suggests). If those guards expose
// `&T`, a `T` that is `Send` but not `Sync` would be accessed concurrently from several threads.
// NOTE(review): confirm whether a `T: Sync` bound is also required here (compare
// `RwLock<T>: Sync`, which requires `T: Send + Sync`).
unsafe impl<T: Send> Sync for GdCellBlocking<T> {}
/// Bookkeeping of which threads hold references into a `GdCellBlocking`.
///
/// `mut_thread` is the thread most recently recorded via `claim_mut_ref` (initially the thread
/// that created the tracker). `shared_counts` maps each thread to the number of shared
/// references it currently holds; threads holding none have no entry.
#[derive(Debug)]
pub(crate) struct ThreadTracker {
    mut_thread: thread::ThreadId,
    shared_counts: HashMap<thread::ThreadId, usize>,
}

impl Default for ThreadTracker {
    /// Creates a tracker with no shared references that records the calling thread as the
    /// mutable-reference holder.
    fn default() -> Self {
        Self {
            mut_thread: thread::current().id(),
            shared_counts: HashMap::new(),
        }
    }
}

impl ThreadTracker {
    /// Number of shared references recorded for the calling thread (0 when it has none).
    pub fn current_thread_shared_count(&self) -> usize {
        self.shared_counts
            .get(&thread::current().id())
            .copied()
            .unwrap_or(0)
    }

    /// Records one more shared reference held by the calling thread.
    pub fn increment_current_thread_shared_count(&mut self) {
        *self
            .shared_counts
            .entry(thread::current().id())
            .or_insert(0) += 1;
    }

    /// Releases one shared reference held by the calling thread.
    ///
    /// The thread's entry is removed once its count reaches zero. As a result, the map does
    /// not accumulate one stale entry for every thread that ever borrowed. Every stored count
    /// is also at least 1, so the subtraction below cannot underflow.
    ///
    /// # Panics
    /// In debug builds, panics if the calling thread has no recorded shared reference; in
    /// release builds that case is ignored.
    pub fn decrement_current_thread_shared_count(&mut self) {
        let thread_id = thread::current().id();
        let entry = self.shared_counts.get_mut(&thread_id);

        debug_assert!(
            entry.is_some(),
            "No shared reference count exists for {thread_id:?}."
        );

        let Some(count) = entry else {
            return;
        };

        *count -= 1;
        if *count == 0 {
            // Prune so a later stray decrement hits the `None` path instead of wrapping to
            // `usize::MAX` (which would make the thread look like it holds shared refs).
            self.shared_counts.remove(&thread_id);
        }
    }

    /// Whether the calling thread is the one recorded as the mutable-reference holder.
    pub fn current_thread_has_mut_ref(&self) -> bool {
        self.mut_thread == thread::current().id()
    }

    /// Records the calling thread as the mutable-reference holder.
    fn claim_mut_ref(&mut self) {
        self.mut_thread = thread::current().id();
    }
}