1use crate::deadlock::{DeadlockDetector, ResourceId, ResourceInfo, ResourceKind};
7use crate::inspector::Inspector;
8use crate::instrument::current_task_id;
9use crate::sync::{LockMetrics, MetricsTracker, WaitTimer};
10
11use std::fmt;
12use std::sync::Arc;
13use tokio::sync::Semaphore as TokioSemaphore;
14
15pub struct Semaphore {
54 inner: TokioSemaphore,
56 name: String,
58 resource_id: ResourceId,
60 metrics: Arc<MetricsTracker>,
62 initial_permits: usize,
64}
65
66impl Semaphore {
67 pub fn new(permits: usize, name: impl Into<String>) -> Self {
83 let name = name.into();
84 let resource_info = ResourceInfo::new(ResourceKind::Semaphore, name.clone());
85 let resource_id = resource_info.id;
86
87 let detector = Inspector::global().deadlock_detector();
89 let _ = detector.register_resource(resource_info);
90
91 Self {
92 inner: TokioSemaphore::new(permits),
93 name,
94 resource_id,
95 metrics: Arc::new(MetricsTracker::new()),
96 initial_permits: permits,
97 }
98 }
99
100 pub async fn acquire(&self) -> Result<SemaphorePermit<'_>, AcquireError> {
120 let detector = Inspector::global().deadlock_detector();
121 let task_id = current_task_id();
122
123 if let Some(tid) = task_id {
125 detector.wait_for(tid, self.resource_id);
126 }
127
128 let timer = WaitTimer::start();
129
130 if let Ok(permit) = self.inner.acquire().await {
132 let wait_time = timer.elapsed_if_contended();
134 self.metrics.record_acquisition(wait_time);
135
136 if let Some(tid) = task_id {
138 detector.acquire(tid, self.resource_id);
139 }
140
141 Ok(SemaphorePermit {
142 permit,
143 resource_id: self.resource_id,
144 task_id,
145 detector: detector.clone(),
146 })
147 } else {
148 if let Some(tid) = task_id {
150 detector.release(tid, self.resource_id);
151 }
152 Err(AcquireError(()))
153 }
154 }
155
156 pub async fn acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, AcquireError> {
174 let detector = Inspector::global().deadlock_detector();
175 let task_id = current_task_id();
176
177 if let Some(tid) = task_id {
178 detector.wait_for(tid, self.resource_id);
179 }
180
181 let timer = WaitTimer::start();
182
183 if let Ok(permit) = self.inner.acquire_many(n).await {
184 let wait_time = timer.elapsed_if_contended();
185 self.metrics.record_acquisition(wait_time);
186
187 if let Some(tid) = task_id {
188 detector.acquire(tid, self.resource_id);
189 }
190
191 Ok(SemaphorePermit {
192 permit,
193 resource_id: self.resource_id,
194 task_id,
195 detector: detector.clone(),
196 })
197 } else {
198 if let Some(tid) = task_id {
199 detector.release(tid, self.resource_id);
200 }
201 Err(AcquireError(()))
202 }
203 }
204
205 pub fn try_acquire(&self) -> Result<SemaphorePermit<'_>, TryAcquireError> {
224 let detector = Inspector::global().deadlock_detector();
225 let task_id = current_task_id();
226
227 match self.inner.try_acquire() {
228 Ok(permit) => {
229 self.metrics.record_acquisition(None);
230
231 if let Some(tid) = task_id {
232 detector.acquire(tid, self.resource_id);
233 }
234
235 Ok(SemaphorePermit {
236 permit,
237 resource_id: self.resource_id,
238 task_id,
239 detector: detector.clone(),
240 })
241 }
242 Err(tokio::sync::TryAcquireError::NoPermits) => Err(TryAcquireError::NoPermits),
243 Err(tokio::sync::TryAcquireError::Closed) => Err(TryAcquireError::Closed),
244 }
245 }
246
247 pub fn try_acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, TryAcquireError> {
249 let detector = Inspector::global().deadlock_detector();
250 let task_id = current_task_id();
251
252 match self.inner.try_acquire_many(n) {
253 Ok(permit) => {
254 self.metrics.record_acquisition(None);
255
256 if let Some(tid) = task_id {
257 detector.acquire(tid, self.resource_id);
258 }
259
260 Ok(SemaphorePermit {
261 permit,
262 resource_id: self.resource_id,
263 task_id,
264 detector: detector.clone(),
265 })
266 }
267 Err(tokio::sync::TryAcquireError::NoPermits) => Err(TryAcquireError::NoPermits),
268 Err(tokio::sync::TryAcquireError::Closed) => Err(TryAcquireError::Closed),
269 }
270 }
271
272 #[must_use]
274 pub fn available_permits(&self) -> usize {
275 self.inner.available_permits()
276 }
277
278 pub fn add_permits(&self, n: usize) {
284 self.inner.add_permits(n);
285 }
286
287 pub fn close(&self) {
291 self.inner.close();
292 }
293
294 #[must_use]
296 pub fn is_closed(&self) -> bool {
297 self.inner.is_closed()
298 }
299
300 #[must_use]
314 pub fn metrics(&self) -> LockMetrics {
315 self.metrics.get_metrics()
316 }
317
318 pub fn reset_metrics(&self) {
320 self.metrics.reset();
321 }
322
323 #[must_use]
325 pub fn name(&self) -> &str {
326 &self.name
327 }
328
329 #[must_use]
331 pub fn resource_id(&self) -> ResourceId {
332 self.resource_id
333 }
334
335 #[must_use]
337 pub fn initial_permits(&self) -> usize {
338 self.initial_permits
339 }
340}
341
342impl fmt::Debug for Semaphore {
343 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
344 let metrics = self.metrics();
345 f.debug_struct("Semaphore")
346 .field("name", &self.name)
347 .field("resource_id", &self.resource_id)
348 .field("initial_permits", &self.initial_permits)
349 .field("available_permits", &self.available_permits())
350 .field("acquisitions", &metrics.acquisitions)
351 .field("contentions", &metrics.contentions)
352 .finish()
353 }
354}
355
/// Error returned by [`Semaphore::acquire`] and [`Semaphore::acquire_many`] when the
/// semaphore has been closed.
///
/// The private `()` field keeps the type from being constructed outside this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AcquireError(());

impl fmt::Display for AcquireError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("semaphore closed")
    }
}

impl std::error::Error for AcquireError {}
367
/// Error returned by [`Semaphore::try_acquire`] and [`Semaphore::try_acquire_many`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TryAcquireError {
    /// Not enough permits were available at the time of the call.
    NoPermits,
    /// The semaphore has been closed.
    Closed,
}

impl fmt::Display for TryAcquireError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::NoPermits => "no permits available",
            Self::Closed => "semaphore closed",
        };
        f.write_str(message)
    }
}

impl std::error::Error for TryAcquireError {}
387
388pub struct SemaphorePermit<'a> {
393 permit: tokio::sync::SemaphorePermit<'a>,
394 resource_id: ResourceId,
395 task_id: Option<crate::task::TaskId>,
396 detector: DeadlockDetector,
397}
398
399impl SemaphorePermit<'_> {
400 pub fn forget(self) {
404 let mut this = std::mem::ManuallyDrop::new(self);
406 this.task_id = None;
408 let permit = unsafe { std::ptr::read(&this.permit) };
410 permit.forget();
411 }
412}
413
414impl Drop for SemaphorePermit<'_> {
415 fn drop(&mut self) {
416 if let Some(tid) = self.task_id {
417 self.detector.release(tid, self.resource_id);
418 }
419 }
420}
421
422impl fmt::Debug for SemaphorePermit<'_> {
423 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
424 f.debug_struct("SemaphorePermit")
425 .field("resource_id", &self.resource_id)
426 .finish()
427 }
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_basic_acquire_release() {
        let semaphore = Semaphore::new(2, "test_sem");

        let permit1 = semaphore.acquire().await.unwrap();
        assert_eq!(semaphore.available_permits(), 1);

        let permit2 = semaphore.acquire().await.unwrap();
        assert_eq!(semaphore.available_permits(), 0);

        drop(permit1);
        assert_eq!(semaphore.available_permits(), 1);

        drop(permit2);
        assert_eq!(semaphore.available_permits(), 2);

        let metrics = semaphore.metrics();
        assert_eq!(metrics.acquisitions, 2);
    }

    #[tokio::test]
    async fn test_try_acquire() {
        let semaphore = Semaphore::new(1, "test_sem");

        let permit = semaphore.try_acquire();
        assert!(permit.is_ok());

        let permit2 = semaphore.try_acquire();
        assert!(matches!(permit2, Err(TryAcquireError::NoPermits)));

        drop(permit);

        let permit3 = semaphore.try_acquire();
        assert!(permit3.is_ok());
    }

    #[tokio::test]
    async fn test_try_acquire_many() {
        let semaphore = Semaphore::new(3, "batch_sem");

        let permit = semaphore.try_acquire_many(2).unwrap();
        assert_eq!(semaphore.available_permits(), 1);

        assert!(matches!(
            semaphore.try_acquire_many(2),
            Err(TryAcquireError::NoPermits)
        ));

        drop(permit);
        assert_eq!(semaphore.available_permits(), 3);
    }

    #[tokio::test]
    async fn test_acquire_many() {
        let semaphore = Semaphore::new(5, "test_sem");

        let permit = semaphore.acquire_many(3).await.unwrap();
        assert_eq!(semaphore.available_permits(), 2);

        drop(permit);
        assert_eq!(semaphore.available_permits(), 5);
    }

    #[tokio::test]
    async fn test_contention() {
        use std::sync::Arc;
        use tokio::time::{sleep, Duration};

        let semaphore = Arc::new(Semaphore::new(1, "contended_sem"));
        let mut handles = vec![];

        for _ in 0..5 {
            let sem = semaphore.clone();
            handles.push(tokio::spawn(async move {
                let _permit = sem.acquire().await.unwrap();
                sleep(Duration::from_millis(10)).await;
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let metrics = semaphore.metrics();
        assert_eq!(metrics.acquisitions, 5);
        assert!(metrics.contentions > 0);
    }

    #[tokio::test]
    async fn test_close() {
        let semaphore = Semaphore::new(1, "closeable");

        let _permit = semaphore.acquire().await.unwrap();

        semaphore.close();
        assert!(semaphore.is_closed());

        let result = semaphore.try_acquire();
        assert!(matches!(result, Err(TryAcquireError::Closed)));
    }

    #[tokio::test]
    async fn test_acquire_on_closed_semaphore() {
        let semaphore = Semaphore::new(1, "closed_sem");
        semaphore.close();

        assert_eq!(semaphore.acquire().await.unwrap_err(), AcquireError(()));
        assert_eq!(semaphore.acquire_many(1).await.unwrap_err(), AcquireError(()));
    }

    #[tokio::test]
    async fn test_cancelled_acquire_does_not_consume_permit() {
        use tokio::time::{timeout, Duration};

        let semaphore = Semaphore::new(1, "cancellable");
        let held = semaphore.acquire().await.unwrap();

        // The second acquire can never succeed while `held` is alive, so the timeout
        // drops (cancels) it while it is suspended.
        let attempt = timeout(Duration::from_millis(5), semaphore.acquire()).await;
        assert!(attempt.is_err());

        drop(held);
        assert_eq!(semaphore.available_permits(), 1);
        assert!(semaphore.try_acquire().is_ok());
    }

    #[tokio::test]
    async fn test_forget_reduces_capacity() {
        let semaphore = Semaphore::new(2, "forgetful");

        let permit = semaphore.acquire().await.unwrap();
        permit.forget();
        assert_eq!(semaphore.available_permits(), 1);
        assert_eq!(semaphore.initial_permits(), 2);

        let _remaining = semaphore.try_acquire().unwrap();
        assert!(matches!(
            semaphore.try_acquire(),
            Err(TryAcquireError::NoPermits)
        ));
    }

    #[tokio::test]
    async fn test_add_permits() {
        let semaphore = Semaphore::new(1, "expandable");

        let _permit = semaphore.acquire().await.unwrap();
        assert_eq!(semaphore.available_permits(), 0);

        semaphore.add_permits(2);
        assert_eq!(semaphore.available_permits(), 2);
    }
}