async_inspect/sync/
semaphore.rs

1//! Tracked Semaphore implementation
2//!
3//! A drop-in replacement for `tokio::sync::Semaphore` that automatically tracks
4//! permit acquisition and integrates with async-inspect's deadlock detection.
5
6use crate::deadlock::{DeadlockDetector, ResourceId, ResourceInfo, ResourceKind};
7use crate::inspector::Inspector;
8use crate::instrument::current_task_id;
9use crate::sync::{LockMetrics, MetricsTracker, WaitTimer};
10
11use std::fmt;
12use std::sync::Arc;
13use tokio::sync::Semaphore as TokioSemaphore;
14
15/// A tracked semaphore that automatically records acquisition metrics
16/// and integrates with deadlock detection.
17///
18/// This is a drop-in replacement for `tokio::sync::Semaphore` with additional
19/// observability features.
20///
21/// # Example
22///
23/// ```rust,no_run
24/// use async_inspect::sync::Semaphore;
25///
26/// #[tokio::main]
27/// async fn main() {
28///     // Allow up to 3 concurrent operations
29///     let semaphore = Semaphore::new(3, "connection_pool");
30///     let semaphore = std::sync::Arc::new(semaphore);
31///
32///     let mut handles = vec![];
33///
34///     for i in 0..10 {
35///         let sem = semaphore.clone();
36///         handles.push(tokio::spawn(async move {
37///             let _permit = sem.acquire().await.unwrap();
38///             println!("Task {} acquired permit", i);
39///             tokio::time::sleep(std::time::Duration::from_millis(100)).await;
40///         }));
41///     }
42///
43///     for h in handles {
44///         h.await.unwrap();
45///     }
46///
47///     // Check acquisition metrics
48///     let metrics = semaphore.metrics();
49///     println!("Total acquisitions: {}", metrics.acquisitions);
50///     println!("Contentions: {}", metrics.contentions);
51/// }
52/// ```
53pub struct Semaphore {
54    /// The underlying Tokio semaphore
55    inner: TokioSemaphore,
56    /// Name for debugging/display
57    name: String,
58    /// Resource ID for deadlock detection
59    resource_id: ResourceId,
60    /// Acquisition metrics
61    metrics: Arc<MetricsTracker>,
62    /// Initial permit count (for debugging)
63    initial_permits: usize,
64}
65
66impl Semaphore {
67    /// Create a new tracked semaphore with the given number of permits.
68    ///
69    /// # Arguments
70    ///
71    /// * `permits` - The number of permits available
72    /// * `name` - A descriptive name for debugging and metrics
73    ///
74    /// # Example
75    ///
76    /// ```rust,no_run
77    /// use async_inspect::sync::Semaphore;
78    ///
79    /// // Create a semaphore limiting to 5 concurrent operations
80    /// let semaphore = Semaphore::new(5, "rate_limiter");
81    /// ```
82    pub fn new(permits: usize, name: impl Into<String>) -> Self {
83        let name = name.into();
84        let resource_info = ResourceInfo::new(ResourceKind::Semaphore, name.clone());
85        let resource_id = resource_info.id;
86
87        // Register with deadlock detector
88        let detector = Inspector::global().deadlock_detector();
89        let _ = detector.register_resource(resource_info);
90
91        Self {
92            inner: TokioSemaphore::new(permits),
93            name,
94            resource_id,
95            metrics: Arc::new(MetricsTracker::new()),
96            initial_permits: permits,
97        }
98    }
99
100    /// Acquire a permit, blocking until one is available.
101    ///
102    /// # Returns
103    ///
104    /// Returns `Ok(SemaphorePermit)` if a permit was acquired, or
105    /// `Err(AcquireError)` if the semaphore was closed.
106    ///
107    /// # Example
108    ///
109    /// ```rust,no_run
110    /// use async_inspect::sync::Semaphore;
111    ///
112    /// # async fn example() {
113    /// let semaphore = Semaphore::new(3, "pool");
114    /// let permit = semaphore.acquire().await.unwrap();
115    /// // ... use the resource ...
116    /// drop(permit); // Release the permit
117    /// # }
118    /// ```
119    pub async fn acquire(&self) -> Result<SemaphorePermit<'_>, AcquireError> {
120        let detector = Inspector::global().deadlock_detector();
121        let task_id = current_task_id();
122
123        // Record that we're waiting for this resource
124        if let Some(tid) = task_id {
125            detector.wait_for(tid, self.resource_id);
126        }
127
128        let timer = WaitTimer::start();
129
130        // Actually acquire the permit
131        if let Ok(permit) = self.inner.acquire().await {
132            // Record metrics
133            let wait_time = timer.elapsed_if_contended();
134            self.metrics.record_acquisition(wait_time);
135
136            // Record successful acquisition
137            if let Some(tid) = task_id {
138                detector.acquire(tid, self.resource_id);
139            }
140
141            Ok(SemaphorePermit {
142                permit,
143                resource_id: self.resource_id,
144                task_id,
145                detector: detector.clone(),
146            })
147        } else {
148            // Clear wait state on error
149            if let Some(tid) = task_id {
150                detector.release(tid, self.resource_id);
151            }
152            Err(AcquireError(()))
153        }
154    }
155
156    /// Acquire multiple permits at once.
157    ///
158    /// # Arguments
159    ///
160    /// * `n` - Number of permits to acquire
161    ///
162    /// # Example
163    ///
164    /// ```rust,no_run
165    /// use async_inspect::sync::Semaphore;
166    ///
167    /// # async fn example() {
168    /// let semaphore = Semaphore::new(10, "batch_pool");
169    /// let permit = semaphore.acquire_many(5).await.unwrap();
170    /// // ... use 5 resources at once ...
171    /// # }
172    /// ```
173    pub async fn acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, AcquireError> {
174        let detector = Inspector::global().deadlock_detector();
175        let task_id = current_task_id();
176
177        if let Some(tid) = task_id {
178            detector.wait_for(tid, self.resource_id);
179        }
180
181        let timer = WaitTimer::start();
182
183        if let Ok(permit) = self.inner.acquire_many(n).await {
184            let wait_time = timer.elapsed_if_contended();
185            self.metrics.record_acquisition(wait_time);
186
187            if let Some(tid) = task_id {
188                detector.acquire(tid, self.resource_id);
189            }
190
191            Ok(SemaphorePermit {
192                permit,
193                resource_id: self.resource_id,
194                task_id,
195                detector: detector.clone(),
196            })
197        } else {
198            if let Some(tid) = task_id {
199                detector.release(tid, self.resource_id);
200            }
201            Err(AcquireError(()))
202        }
203    }
204
205    /// Try to acquire a permit immediately.
206    ///
207    /// Returns `None` if no permits are available.
208    ///
209    /// # Example
210    ///
211    /// ```rust,no_run
212    /// use async_inspect::sync::Semaphore;
213    ///
214    /// let semaphore = Semaphore::new(1, "exclusive");
215    /// let result = semaphore.try_acquire();
216    /// if let Ok(permit) = result {
217    ///     println!("Got the permit!");
218    ///     drop(permit);
219    /// } else {
220    ///     println!("No permits available");
221    /// }
222    /// ```
223    pub fn try_acquire(&self) -> Result<SemaphorePermit<'_>, TryAcquireError> {
224        let detector = Inspector::global().deadlock_detector();
225        let task_id = current_task_id();
226
227        match self.inner.try_acquire() {
228            Ok(permit) => {
229                self.metrics.record_acquisition(None);
230
231                if let Some(tid) = task_id {
232                    detector.acquire(tid, self.resource_id);
233                }
234
235                Ok(SemaphorePermit {
236                    permit,
237                    resource_id: self.resource_id,
238                    task_id,
239                    detector: detector.clone(),
240                })
241            }
242            Err(tokio::sync::TryAcquireError::NoPermits) => Err(TryAcquireError::NoPermits),
243            Err(tokio::sync::TryAcquireError::Closed) => Err(TryAcquireError::Closed),
244        }
245    }
246
247    /// Try to acquire multiple permits immediately.
248    pub fn try_acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, TryAcquireError> {
249        let detector = Inspector::global().deadlock_detector();
250        let task_id = current_task_id();
251
252        match self.inner.try_acquire_many(n) {
253            Ok(permit) => {
254                self.metrics.record_acquisition(None);
255
256                if let Some(tid) = task_id {
257                    detector.acquire(tid, self.resource_id);
258                }
259
260                Ok(SemaphorePermit {
261                    permit,
262                    resource_id: self.resource_id,
263                    task_id,
264                    detector: detector.clone(),
265                })
266            }
267            Err(tokio::sync::TryAcquireError::NoPermits) => Err(TryAcquireError::NoPermits),
268            Err(tokio::sync::TryAcquireError::Closed) => Err(TryAcquireError::Closed),
269        }
270    }
271
272    /// Get the current number of available permits.
273    #[must_use]
274    pub fn available_permits(&self) -> usize {
275        self.inner.available_permits()
276    }
277
278    /// Add permits to the semaphore.
279    ///
280    /// # Arguments
281    ///
282    /// * `n` - Number of permits to add
283    pub fn add_permits(&self, n: usize) {
284        self.inner.add_permits(n);
285    }
286
287    /// Close the semaphore.
288    ///
289    /// All pending acquire operations will fail with an error.
290    pub fn close(&self) {
291        self.inner.close();
292    }
293
294    /// Check if the semaphore is closed.
295    #[must_use]
296    pub fn is_closed(&self) -> bool {
297        self.inner.is_closed()
298    }
299
300    /// Get the current acquisition metrics for this semaphore.
301    ///
302    /// # Example
303    ///
304    /// ```rust,no_run
305    /// use async_inspect::sync::Semaphore;
306    ///
307    /// let semaphore = Semaphore::new(5, "pool");
308    /// // ... some operations ...
309    /// let metrics = semaphore.metrics();
310    /// println!("Acquisitions: {}", metrics.acquisitions);
311    /// println!("Contention rate: {:.1}%", metrics.contention_rate() * 100.0);
312    /// ```
313    #[must_use]
314    pub fn metrics(&self) -> LockMetrics {
315        self.metrics.get_metrics()
316    }
317
318    /// Reset the acquisition metrics.
319    pub fn reset_metrics(&self) {
320        self.metrics.reset();
321    }
322
323    /// Get the name of this semaphore.
324    #[must_use]
325    pub fn name(&self) -> &str {
326        &self.name
327    }
328
329    /// Get the resource ID for deadlock detection.
330    #[must_use]
331    pub fn resource_id(&self) -> ResourceId {
332        self.resource_id
333    }
334
335    /// Get the initial number of permits.
336    #[must_use]
337    pub fn initial_permits(&self) -> usize {
338        self.initial_permits
339    }
340}
341
342impl fmt::Debug for Semaphore {
343    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
344        let metrics = self.metrics();
345        f.debug_struct("Semaphore")
346            .field("name", &self.name)
347            .field("resource_id", &self.resource_id)
348            .field("initial_permits", &self.initial_permits)
349            .field("available_permits", &self.available_permits())
350            .field("acquisitions", &metrics.acquisitions)
351            .field("contentions", &metrics.contentions)
352            .finish()
353    }
354}
355
/// Error produced by the waiting acquire methods when the semaphore has been
/// closed.
///
/// The private unit field keeps the type from being constructed outside this
/// module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AcquireError(());

impl fmt::Display for AcquireError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("semaphore closed")
    }
}

impl std::error::Error for AcquireError {}
367
/// Error produced by the non-waiting `try_acquire*` methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TryAcquireError {
    /// The semaphore is open but the requested permits are not available.
    NoPermits,
    /// The semaphore has been closed.
    Closed,
}

impl fmt::Display for TryAcquireError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the message first, then emit it in one place.
        let message = match self {
            Self::NoPermits => "no permits available",
            Self::Closed => "semaphore closed",
        };
        f.write_str(message)
    }
}

impl std::error::Error for TryAcquireError {}
387
388/// RAII guard for a tracked semaphore permit.
389///
390/// When this guard is dropped, the permit is released and the deadlock
391/// detector is notified.
392pub struct SemaphorePermit<'a> {
393    permit: tokio::sync::SemaphorePermit<'a>,
394    resource_id: ResourceId,
395    task_id: Option<crate::task::TaskId>,
396    detector: DeadlockDetector,
397}
398
399impl SemaphorePermit<'_> {
400    /// Forget this permit, preventing it from being released.
401    ///
402    /// This is useful for implementing manual permit management.
403    pub fn forget(self) {
404        // Use ManuallyDrop to prevent the Drop impl from running
405        let mut this = std::mem::ManuallyDrop::new(self);
406        // Clear task_id so if Drop somehow ran, it wouldn't notify
407        this.task_id = None;
408        // SAFETY: We're taking ownership and forgetting, so this is the last access
409        let permit = unsafe { std::ptr::read(&this.permit) };
410        permit.forget();
411    }
412}
413
414impl Drop for SemaphorePermit<'_> {
415    fn drop(&mut self) {
416        if let Some(tid) = self.task_id {
417            self.detector.release(tid, self.resource_id);
418        }
419    }
420}
421
422impl fmt::Debug for SemaphorePermit<'_> {
423    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
424        f.debug_struct("SemaphorePermit")
425            .field("resource_id", &self.resource_id)
426            .finish()
427    }
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Permits are handed out and returned one at a time, and every
    /// successful acquire is counted.
    #[tokio::test]
    async fn test_basic_acquire_release() {
        let sem = Semaphore::new(2, "test_sem");

        let first = sem.acquire().await.unwrap();
        assert_eq!(sem.available_permits(), 1);

        let second = sem.acquire().await.unwrap();
        assert_eq!(sem.available_permits(), 0);

        drop(first);
        assert_eq!(sem.available_permits(), 1);

        drop(second);
        assert_eq!(sem.available_permits(), 2);

        let recorded = sem.metrics();
        assert_eq!(recorded.acquisitions, 2);
    }

    /// `try_acquire` succeeds while a permit is free and reports
    /// `NoPermits` while it is taken.
    #[tokio::test]
    async fn test_try_acquire() {
        let sem = Semaphore::new(1, "test_sem");

        let held = sem.try_acquire();
        assert!(held.is_ok());

        // The single permit is taken, so a second attempt must fail
        let denied = sem.try_acquire();
        assert!(matches!(denied, Err(TryAcquireError::NoPermits)));

        drop(held);

        // The permit is back, so this attempt goes through
        let retried = sem.try_acquire();
        assert!(retried.is_ok());
    }

    /// A multi-permit guard takes and returns all of its permits together.
    #[tokio::test]
    async fn test_acquire_many() {
        let sem = Semaphore::new(5, "test_sem");

        let batch = sem.acquire_many(3).await.unwrap();
        assert_eq!(sem.available_permits(), 2);

        drop(batch);
        assert_eq!(sem.available_permits(), 5);
    }

    /// Several tasks competing for one permit all get through, and the
    /// metrics register contention.
    #[tokio::test]
    async fn test_contention() {
        use std::sync::Arc;
        use tokio::time::{sleep, Duration};

        let shared = Arc::new(Semaphore::new(1, "contended_sem"));

        let workers: Vec<_> = (0..5)
            .map(|_| {
                let sem = shared.clone();
                tokio::spawn(async move {
                    let _permit = sem.acquire().await.unwrap();
                    sleep(Duration::from_millis(10)).await;
                })
            })
            .collect();

        for worker in workers {
            worker.await.unwrap();
        }

        let recorded = shared.metrics();
        assert_eq!(recorded.acquisitions, 5);
        // With a single permit, some of the tasks must have had to wait
        assert!(recorded.contentions > 0);
    }

    /// After `close`, the semaphore reports closed and `try_acquire`
    /// returns `Closed`.
    #[tokio::test]
    async fn test_close() {
        let sem = Semaphore::new(1, "closeable");

        // Hold the only permit while closing
        let _permit = sem.acquire().await.unwrap();

        sem.close();
        assert!(sem.is_closed());

        // A closed semaphore rejects further attempts
        let outcome = sem.try_acquire();
        assert!(matches!(outcome, Err(TryAcquireError::Closed)));
    }

    /// Added permits become available even while the original one is held.
    #[tokio::test]
    async fn test_add_permits() {
        let sem = Semaphore::new(1, "expandable");

        let _permit = sem.acquire().await.unwrap();
        assert_eq!(sem.available_permits(), 0);

        sem.add_permits(2);
        assert_eq!(sem.available_permits(), 2);
    }
}