use crate::deadlock::{DeadlockDetector, ResourceId, ResourceInfo, ResourceKind};
use crate::inspector::Inspector;
use crate::instrument::current_task_id;
use crate::sync::{LockMetrics, MetricsTracker, WaitTimer};
use std::fmt;
use std::sync::Arc;
use tokio::sync::Semaphore as TokioSemaphore;
/// Wrapper around `tokio::sync::Semaphore` that reports waits, acquisitions and
/// releases of its permits to the global deadlock detector and records acquisition
/// metrics.
pub struct Semaphore {
/// Underlying tokio semaphore that actually hands out permits.
inner: TokioSemaphore,
/// Human-readable name; also passed to the deadlock detector at registration.
name: String,
/// Identifier under which this semaphore is known to the deadlock detector.
resource_id: ResourceId,
/// Acquisition / contention counters exposed through `metrics()`.
metrics: Arc<MetricsTracker>,
/// Permit count given to `new`; not adjusted by `add_permits` or forgotten permits.
initial_permits: usize,
}
impl Semaphore {
pub fn new(permits: usize, name: impl Into<String>) -> Self {
let name = name.into();
let resource_info = ResourceInfo::new(ResourceKind::Semaphore, name.clone());
let resource_id = resource_info.id;
let detector = Inspector::global().deadlock_detector();
let _ = detector.register_resource(resource_info);
Self {
inner: TokioSemaphore::new(permits),
name,
resource_id,
metrics: Arc::new(MetricsTracker::new()),
initial_permits: permits,
}
}
pub async fn acquire(&self) -> Result<SemaphorePermit<'_>, AcquireError> {
let detector = Inspector::global().deadlock_detector();
let task_id = current_task_id();
if let Some(tid) = task_id {
detector.wait_for(tid, self.resource_id);
}
let timer = WaitTimer::start();
if let Ok(permit) = self.inner.acquire().await {
let wait_time = timer.elapsed_if_contended();
self.metrics.record_acquisition(wait_time);
if let Some(tid) = task_id {
detector.acquire(tid, self.resource_id);
}
Ok(SemaphorePermit {
permit,
resource_id: self.resource_id,
task_id,
detector: detector.clone(),
})
} else {
if let Some(tid) = task_id {
detector.release(tid, self.resource_id);
}
Err(AcquireError(()))
}
}
pub async fn acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, AcquireError> {
let detector = Inspector::global().deadlock_detector();
let task_id = current_task_id();
if let Some(tid) = task_id {
detector.wait_for(tid, self.resource_id);
}
let timer = WaitTimer::start();
if let Ok(permit) = self.inner.acquire_many(n).await {
let wait_time = timer.elapsed_if_contended();
self.metrics.record_acquisition(wait_time);
if let Some(tid) = task_id {
detector.acquire(tid, self.resource_id);
}
Ok(SemaphorePermit {
permit,
resource_id: self.resource_id,
task_id,
detector: detector.clone(),
})
} else {
if let Some(tid) = task_id {
detector.release(tid, self.resource_id);
}
Err(AcquireError(()))
}
}
pub fn try_acquire(&self) -> Result<SemaphorePermit<'_>, TryAcquireError> {
let detector = Inspector::global().deadlock_detector();
let task_id = current_task_id();
match self.inner.try_acquire() {
Ok(permit) => {
self.metrics.record_acquisition(None);
if let Some(tid) = task_id {
detector.acquire(tid, self.resource_id);
}
Ok(SemaphorePermit {
permit,
resource_id: self.resource_id,
task_id,
detector: detector.clone(),
})
}
Err(tokio::sync::TryAcquireError::NoPermits) => Err(TryAcquireError::NoPermits),
Err(tokio::sync::TryAcquireError::Closed) => Err(TryAcquireError::Closed),
}
}
pub fn try_acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, TryAcquireError> {
let detector = Inspector::global().deadlock_detector();
let task_id = current_task_id();
match self.inner.try_acquire_many(n) {
Ok(permit) => {
self.metrics.record_acquisition(None);
if let Some(tid) = task_id {
detector.acquire(tid, self.resource_id);
}
Ok(SemaphorePermit {
permit,
resource_id: self.resource_id,
task_id,
detector: detector.clone(),
})
}
Err(tokio::sync::TryAcquireError::NoPermits) => Err(TryAcquireError::NoPermits),
Err(tokio::sync::TryAcquireError::Closed) => Err(TryAcquireError::Closed),
}
}
#[must_use]
pub fn available_permits(&self) -> usize {
self.inner.available_permits()
}
pub fn add_permits(&self, n: usize) {
self.inner.add_permits(n);
}
pub fn close(&self) {
self.inner.close();
}
#[must_use]
pub fn is_closed(&self) -> bool {
self.inner.is_closed()
}
#[must_use]
pub fn metrics(&self) -> LockMetrics {
self.metrics.get_metrics()
}
pub fn reset_metrics(&self) {
self.metrics.reset();
}
#[must_use]
pub fn name(&self) -> &str {
&self.name
}
#[must_use]
pub fn resource_id(&self) -> ResourceId {
self.resource_id
}
#[must_use]
pub fn initial_permits(&self) -> usize {
self.initial_permits
}
}
impl fmt::Debug for Semaphore {
    /// Shows the semaphore's identity, its permit counts and the acquisition /
    /// contention counters from a metrics snapshot.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let snapshot = self.metrics();
        let mut out = f.debug_struct("Semaphore");
        out.field("name", &self.name);
        out.field("resource_id", &self.resource_id);
        out.field("initial_permits", &self.initial_permits);
        out.field("available_permits", &self.available_permits());
        out.field("acquisitions", &snapshot.acquisitions);
        out.field("contentions", &snapshot.contentions);
        out.finish()
    }
}
/// Error returned by `Semaphore::acquire` / `Semaphore::acquire_many` when the
/// semaphore has been closed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AcquireError(());

impl fmt::Display for AcquireError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("semaphore closed")
    }
}

impl std::error::Error for AcquireError {}
/// Error returned by the non-blocking `try_acquire*` methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TryAcquireError {
    /// Not enough permits were available at the moment of the call.
    NoPermits,
    /// The semaphore has been closed.
    Closed,
}

impl fmt::Display for TryAcquireError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::NoPermits => "no permits available",
            Self::Closed => "semaphore closed",
        };
        f.write_str(message)
    }
}

impl std::error::Error for TryAcquireError {}
/// Guard for permits acquired from a `Semaphore`.
///
/// On drop, the inner tokio permit gives its permits back to the semaphore, and the
/// release is reported to the deadlock detector if the permit was acquired by a task
/// with an id.
pub struct SemaphorePermit<'a> {
/// Underlying tokio permit; returns its permits to the semaphore when dropped.
permit: tokio::sync::SemaphorePermit<'a>,
/// Deadlock-detector id of the semaphore this permit came from.
resource_id: ResourceId,
/// Task that acquired the permit, if `current_task_id()` reported one at that time.
task_id: Option<crate::task::TaskId>,
/// Detector handle used to report the release.
detector: DeadlockDetector,
}
impl SemaphorePermit<'_> {
    /// Consumes the permit without giving its permits back to the semaphore,
    /// permanently reducing the semaphore's capacity (see
    /// `tokio::sync::SemaphorePermit::forget`).
    ///
    /// The deadlock detector is told that the acquiring task no longer holds the
    /// resource. A forgotten permit can never be released, so keeping the hold
    /// recorded would leave a stale edge in the detector's graph forever.
    pub fn forget(self) {
        // Suppress both `Drop for SemaphorePermit` and the field destructors; in
        // particular the inner tokio permit must not return its permits.
        let mut this = std::mem::ManuallyDrop::new(self);
        if let Some(tid) = this.task_id.take() {
            this.detector.release(tid, this.resource_id);
        }
        // SAFETY: `this` is wrapped in `ManuallyDrop`, so neither `SemaphorePermit::drop`
        // nor any field destructor will run. `permit` and `detector` are each read out
        // exactly once and never accessed through `this` afterwards, so each value is
        // consumed/dropped exactly once. The remaining fields (`resource_id`,
        // `task_id`) are plain `Copy` data with nothing to drop.
        let (permit, detector) =
            unsafe { (std::ptr::read(&this.permit), std::ptr::read(&this.detector)) };
        permit.forget();
        // Dropped explicitly so the detector handle is not leaked.
        drop(detector);
    }
}
impl Drop for SemaphorePermit<'_> {
    /// Reports the release to the deadlock detector when the permit was acquired by a
    /// task with an id. Giving the permits back to the semaphore is left to the inner
    /// tokio permit, which is dropped after this runs.
    fn drop(&mut self) {
        let resource = self.resource_id;
        if let Some(holder) = self.task_id {
            self.detector.release(holder, resource);
        }
    }
}
impl fmt::Debug for SemaphorePermit<'_> {
    /// Shows only the id of the semaphore the permit belongs to.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("SemaphorePermit");
        repr.field("resource_id", &self.resource_id);
        repr.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_basic_acquire_release() {
        let semaphore = Semaphore::new(2, "test_sem");
        let permit1 = semaphore.acquire().await.unwrap();
        assert_eq!(semaphore.available_permits(), 1);
        let permit2 = semaphore.acquire().await.unwrap();
        assert_eq!(semaphore.available_permits(), 0);
        drop(permit1);
        assert_eq!(semaphore.available_permits(), 1);
        drop(permit2);
        assert_eq!(semaphore.available_permits(), 2);
        let metrics = semaphore.metrics();
        assert_eq!(metrics.acquisitions, 2);
    }

    #[tokio::test]
    async fn test_try_acquire() {
        let semaphore = Semaphore::new(1, "test_sem");
        let permit = semaphore.try_acquire();
        assert!(permit.is_ok());
        let permit2 = semaphore.try_acquire();
        assert!(matches!(permit2, Err(TryAcquireError::NoPermits)));
        drop(permit);
        let permit3 = semaphore.try_acquire();
        assert!(permit3.is_ok());
    }

    #[tokio::test]
    async fn test_try_acquire_many() {
        let semaphore = Semaphore::new(3, "try_many_sem");
        let permit = semaphore.try_acquire_many(2).unwrap();
        assert_eq!(semaphore.available_permits(), 1);
        assert!(matches!(
            semaphore.try_acquire_many(2),
            Err(TryAcquireError::NoPermits)
        ));
        drop(permit);
        assert_eq!(semaphore.available_permits(), 3);
    }

    #[tokio::test]
    async fn test_acquire_many() {
        let semaphore = Semaphore::new(5, "test_sem");
        let permit = semaphore.acquire_many(3).await.unwrap();
        assert_eq!(semaphore.available_permits(), 2);
        drop(permit);
        assert_eq!(semaphore.available_permits(), 5);
    }

    #[tokio::test]
    async fn test_contention() {
        use std::sync::Arc;
        use tokio::time::{sleep, Duration};
        let semaphore = Arc::new(Semaphore::new(1, "contended_sem"));
        let mut handles = vec![];
        for _ in 0..5 {
            let sem = semaphore.clone();
            handles.push(tokio::spawn(async move {
                let _permit = sem.acquire().await.unwrap();
                sleep(Duration::from_millis(10)).await;
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
        let metrics = semaphore.metrics();
        assert_eq!(metrics.acquisitions, 5);
        assert!(metrics.contentions > 0);
    }

    #[tokio::test]
    async fn test_close() {
        let semaphore = Semaphore::new(1, "closeable");
        let _permit = semaphore.acquire().await.unwrap();
        semaphore.close();
        assert!(semaphore.is_closed());
        let result = semaphore.try_acquire();
        assert!(matches!(result, Err(TryAcquireError::Closed)));
    }

    /// Blocking acquisitions on a closed semaphore fail even if permits remain.
    #[tokio::test]
    async fn test_acquire_after_close_fails() {
        let semaphore = Semaphore::new(2, "closed_before_acquire");
        semaphore.close();
        assert!(semaphore.acquire().await.is_err());
        assert!(semaphore.acquire_many(2).await.is_err());
        assert!(matches!(
            semaphore.try_acquire_many(1),
            Err(TryAcquireError::Closed)
        ));
    }

    /// A task waiting for a permit is woken with an error when the semaphore closes.
    #[tokio::test]
    async fn test_close_wakes_waiter() {
        let semaphore = std::sync::Arc::new(Semaphore::new(0, "close_wakes_waiter"));
        let sem = semaphore.clone();
        let waiter = tokio::spawn(async move {
            let closed = sem.acquire().await.is_err();
            closed
        });
        tokio::task::yield_now().await;
        semaphore.close();
        assert!(waiter.await.unwrap());
    }

    /// Dropping an acquire future mid-wait must not consume or leak a permit.
    #[tokio::test]
    async fn test_cancelled_acquire_does_not_consume_permit() {
        use tokio::time::{timeout, Duration};
        let semaphore = Semaphore::new(1, "cancellable");
        let held = semaphore.acquire().await.unwrap();
        assert!(timeout(Duration::from_millis(5), semaphore.acquire())
            .await
            .is_err());
        drop(held);
        assert_eq!(semaphore.available_permits(), 1);
        assert!(semaphore.acquire().await.is_ok());
    }

    /// Forgetting a permit permanently removes it from the semaphore.
    #[tokio::test]
    async fn test_forget_removes_permit() {
        let semaphore = Semaphore::new(2, "forgettable");
        let permit = semaphore.acquire().await.unwrap();
        permit.forget();
        assert_eq!(semaphore.available_permits(), 1);
        let many = semaphore.acquire_many(1).await.unwrap();
        many.forget();
        assert_eq!(semaphore.available_permits(), 0);
    }

    #[tokio::test]
    async fn test_add_permits() {
        let semaphore = Semaphore::new(1, "expandable");
        let _permit = semaphore.acquire().await.unwrap();
        assert_eq!(semaphore.available_permits(), 0);
        semaphore.add_permits(2);
        assert_eq!(semaphore.available_permits(), 2);
    }
}