use dashmap::DashMap;
use linera_base::identifiers::ChainId;
use std::{
fmt::{self, Debug, Formatter},
sync::{Arc, Weak},
};
use tokio::sync::{Mutex, OwnedMutexGuard};
#[cfg(test)]
#[path = "unit_tests/chain_guards.rs"]
mod unit_tests;
/// Map from a chain ID to a weak reference to the mutex that serializes access to that chain.
///
/// Only weak references are stored, so the map itself never keeps a mutex alive; an entry whose
/// mutex has been freed fails to `upgrade` and is replaced on the next lookup.
type ChainGuardMap = DashMap<ChainId, Weak<Mutex<()>>>;
/// A registry of per-chain locks.
///
/// Clones share the same underlying map (it is behind an [`Arc`]), so guards obtained from any
/// clone exclude each other. Entries are created on demand by [`ChainGuards::guard`] and are
/// normally removed when the last [`ChainGuard`] for that chain is dropped.
#[derive(Clone, Debug, Default)]
pub struct ChainGuards {
    /// The shared map of per-chain mutexes.
    guards: Arc<ChainGuardMap>,
}
impl ChainGuards {
    /// Acquires exclusive access to `chain_id`.
    ///
    /// Waits until no other [`ChainGuard`] for the same chain, obtained from this registry or any
    /// of its clones, is alive. The lock is held until the returned [`ChainGuard`] is dropped.
    ///
    /// NOTE: if the returned future is dropped while still waiting for the lock, the map entry
    /// created for `chain_id` is not removed here. It is left pointing at a freed mutex until
    /// the next `guard` call for the same chain replaces it; the resulting guard then removes the
    /// entry when it is dropped.
    pub async fn guard(&self, chain_id: ChainId) -> ChainGuard {
        let mutex = self.get_or_create_lock(chain_id);
        ChainGuard {
            chain_id,
            guards: Arc::clone(&self.guards),
            guard: Some(mutex.lock_owned().await),
        }
    }

    /// Returns the mutex registered for `chain_id`.
    ///
    /// A new mutex is created and registered if the map has no entry for the chain, or if the
    /// registered mutex has no remaining strong references.
    fn get_or_create_lock(&self, chain_id: ChainId) -> Arc<Mutex<()>> {
        // DashMap's `RefMut` keeps the shard write-locked while `entry` is alive. That makes the
        // upgrade-or-replace below atomic with respect to other lookups and to
        // `ChainGuard::drop` for the same chain.
        //
        // A vacant slot is filled with an empty `Weak`, which never upgrades. Vacant slots and
        // slots whose mutex has been freed therefore take the same replacement path.
        let mut entry = self.guards.entry(chain_id).or_insert_with(Weak::new);
        if let Some(mutex) = entry.upgrade() {
            return mutex;
        }
        let (mutex, weak_reference) = Self::create_new_mutex();
        *entry = weak_reference;
        mutex
    }

    /// Creates a fresh mutex and returns its strong handle together with a weak reference to it
    /// (the weak reference is what gets stored in the map).
    fn create_new_mutex() -> (Arc<Mutex<()>>, Weak<Mutex<()>>) {
        let new_guard = Arc::new(Mutex::new(()));
        let weak_reference = Arc::downgrade(&new_guard);
        (new_guard, weak_reference)
    }

    /// Returns the number of entries currently in the map.
    ///
    /// This count can include entries whose mutex has already been freed.
    #[cfg(test)]
    pub(crate) fn active_guards(&self) -> usize {
        self.guards.len()
    }
}
/// Holds the lock for a single chain, obtained from a [`ChainGuards`] registry, until dropped.
///
/// On drop, the lock is released. The chain's entry is removed from the shared map if no other
/// strong reference to the same mutex remains.
pub struct ChainGuard {
    /// The chain whose lock this guard holds.
    chain_id: ChainId,
    /// The registry map the mutex was taken from; used to clean up the entry on drop.
    guards: Arc<ChainGuardMap>,
    /// The held lock. It is always `Some` until `Drop::drop` takes it.
    guard: Option<OwnedMutexGuard<()>>,
}
impl Drop for ChainGuard {
    /// Releases this chain's lock and removes the chain's map entry if nothing else still
    /// references the same mutex.
    fn drop(&mut self) {
        // The release and the reference check both run inside `remove_if`'s closure, while the
        // map shard is locked. A concurrent `get_or_create_lock` therefore cannot observe a
        // half-cleaned entry.
        self.guards.remove_if(&self.chain_id, |_chain_id, _stored_weak| {
            let owned_guard = self
                .guard
                .take()
                .expect("Guard dropped before `Drop` implementation");
            let weak_mutex = Arc::downgrade(OwnedMutexGuard::mutex(&owned_guard));
            // Releasing the guard unlocks the mutex and gives up this guard's strong reference.
            drop(owned_guard);
            // A surviving strong reference belongs to another guard, or to a pending `guard`
            // call for the same chain, so the entry is kept in that case.
            weak_mutex.upgrade().is_none()
        });
    }
}
impl Debug for ChainGuard {
    /// Formats the guard by its chain ID only.
    ///
    /// The shared map and the held lock are omitted, which is signalled by the non-exhaustive
    /// marker.
    fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
        let mut debug_struct = formatter.debug_struct("ChainGuard");
        debug_struct.field("chain_id", &self.chain_id);
        debug_struct.finish_non_exhaustive()
    }
}
#[cfg(test)]
impl ChainGuard {
    /// Builds a guard for `ChainId::root(0)` from a brand-new, otherwise unused registry.
    ///
    /// The lock is therefore acquired without contention. Intended for tests.
    pub async fn dummy() -> Self {
        ChainGuards::default().guard(ChainId::root(0)).await
    }
}