use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
/// A non-blocking counting semaphore that caps how many operations may be in
/// flight at the same time (the "bulkhead" isolation pattern).
///
/// Permits are obtained with [`BulkheadSemaphore::try_acquire`], which never
/// waits: it either returns a [`BulkheadGuard`] or `None` when the limit is
/// already reached. The permit is returned when the guard is dropped,
/// including when the holder unwinds from a panic.
#[derive(Debug)]
pub struct BulkheadSemaphore {
    /// Maximum number of permits that may be held simultaneously (always > 0).
    limit: usize,
    /// Number of permits currently held; `try_acquire` never raises it past `limit`.
    acquired: AtomicUsize,
}

impl BulkheadSemaphore {
    /// Creates a bulkhead that allows at most `limit` concurrent permits.
    ///
    /// # Panics
    ///
    /// Panics if `limit` is 0, because such a bulkhead could never admit work.
    pub fn new(limit: usize) -> Self {
        assert!(limit > 0, "Bulkhead limit must be greater than 0");
        Self {
            limit,
            acquired: AtomicUsize::new(0),
        }
    }

    /// Attempts to take one permit without blocking.
    ///
    /// Returns `Some(guard)` when a permit was available and `None` when
    /// `limit` permits are already held. The permit is held for exactly as long
    /// as the returned guard lives, so the result must be bound to a named
    /// variable; discarding it (or binding it to `_`) releases the permit at once.
    ///
    /// Takes `&Arc<Self>` so that each guard can hold its own strong reference
    /// and keep the semaphore alive independently of the caller.
    #[must_use = "the permit is released as soon as the returned guard is dropped"]
    pub fn try_acquire(self: &Arc<Self>) -> Option<BulkheadGuard> {
        // `fetch_update` retries the compare-and-swap until it either commits
        // `held + 1` or the closure declines because the bulkhead is full.
        // `held < limit` guarantees `held + 1` cannot overflow.
        self.acquired
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |held| {
                if held < self.limit {
                    Some(held + 1)
                } else {
                    None
                }
            })
            .ok()
            .map(|_| BulkheadGuard {
                semaphore: Arc::clone(self),
            })
    }

    /// Returns the number of permits currently held.
    ///
    /// This is a point-in-time snapshot; other threads may acquire or release
    /// permits immediately after it is read.
    pub fn acquired(&self) -> usize {
        self.acquired.load(Ordering::Acquire)
    }

    /// Returns the maximum number of permits this bulkhead hands out.
    pub fn limit(&self) -> usize {
        self.limit
    }

    /// Returns how many additional permits could currently be acquired
    /// (`limit - acquired`, clamped at 0). Like [`Self::acquired`], this is a
    /// point-in-time snapshot.
    pub fn available(&self) -> usize {
        self.limit.saturating_sub(self.acquired())
    }

    /// Gives back one permit; invoked when a [`BulkheadGuard`] is dropped.
    fn release(&self) {
        // `Release` pairs with the `AcqRel` CAS in `try_acquire`, so work done
        // while holding a permit is visible to the next thread that takes one.
        let previous = self.acquired.fetch_sub(1, Ordering::Release);
        // The atomic would silently wrap to `usize::MAX` on underflow; catch
        // any unbalanced release in debug builds.
        debug_assert!(
            previous > 0,
            "BulkheadSemaphore released more permits than it granted"
        );
    }
}

/// RAII permit returned by [`BulkheadSemaphore::try_acquire`].
///
/// Holds a strong reference to its semaphore and returns the permit when dropped.
#[derive(Debug)]
#[must_use = "the permit is released as soon as the guard is dropped"]
pub struct BulkheadGuard {
    semaphore: Arc<BulkheadSemaphore>,
}

impl Drop for BulkheadGuard {
    /// Returns this guard's permit to the semaphore.
    fn drop(&mut self) {
        self.semaphore.release();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Barrier;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_bulkhead_basic_acquire_release() {
        let bulkhead = Arc::new(BulkheadSemaphore::new(3));
        assert_eq!(bulkhead.limit(), 3);
        assert_eq!(bulkhead.acquired(), 0);
        assert_eq!(bulkhead.available(), 3);
        let guard1 = bulkhead.try_acquire();
        assert!(guard1.is_some());
        assert_eq!(bulkhead.acquired(), 1);
        assert_eq!(bulkhead.available(), 2);
        let guard2 = bulkhead.try_acquire();
        assert!(guard2.is_some());
        assert_eq!(bulkhead.acquired(), 2);
        drop(guard1);
        assert_eq!(bulkhead.acquired(), 1);
        assert_eq!(bulkhead.available(), 2);
        drop(guard2);
        assert_eq!(bulkhead.acquired(), 0);
        assert_eq!(bulkhead.available(), 3);
    }

    #[test]
    fn test_bulkhead_at_capacity() {
        let bulkhead = Arc::new(BulkheadSemaphore::new(2));
        let guard1 = bulkhead.try_acquire().expect("Should acquire");
        let guard2 = bulkhead.try_acquire().expect("Should acquire");
        let guard3 = bulkhead.try_acquire();
        assert!(guard3.is_none(), "Should not acquire when at capacity");
        assert_eq!(bulkhead.acquired(), 2);
        drop(guard1);
        let guard4 = bulkhead.try_acquire();
        assert!(guard4.is_some(), "Should acquire after release");
        assert_eq!(bulkhead.acquired(), 2);
        drop(guard2);
        drop(guard4);
    }

    /// Many threads contend at once; verifies that the bulkhead never lets
    /// more than `LIMIT` of them into the guarded section simultaneously.
    #[test]
    fn test_bulkhead_concurrent_access() {
        const LIMIT: usize = 5;
        const THREADS: usize = 10;

        let bulkhead = Arc::new(BulkheadSemaphore::new(LIMIT));
        // Threads currently inside the guarded section, and the maximum that
        // value ever reached. `in_flight` is raised after acquiring and lowered
        // before the guard drops, so it can never exceed the permits held.
        let in_flight = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));
        // Release all workers together so they genuinely contend for permits.
        let start = Arc::new(Barrier::new(THREADS));

        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let bulkhead = Arc::clone(&bulkhead);
                let in_flight = Arc::clone(&in_flight);
                let peak = Arc::clone(&peak);
                let start = Arc::clone(&start);
                thread::spawn(move || {
                    start.wait();
                    match bulkhead.try_acquire() {
                        Some(_guard) => {
                            let now = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                            peak.fetch_max(now, Ordering::SeqCst);
                            thread::sleep(Duration::from_millis(10));
                            in_flight.fetch_sub(1, Ordering::SeqCst);
                            true
                        }
                        None => false,
                    }
                })
            })
            .collect();

        let acquired_count = handles
            .into_iter()
            .map(|handle| handle.join().expect("worker thread panicked"))
            .filter(|&acquired| acquired)
            .count();

        // A thread is only refused while LIMIT others hold permits, so at
        // least LIMIT threads must have succeeded.
        assert!(
            acquired_count >= LIMIT,
            "At least 5 threads should acquire permits"
        );
        assert!(
            peak.load(Ordering::SeqCst) <= LIMIT,
            "Bulkhead admitted more concurrent holders than its limit"
        );
        assert_eq!(bulkhead.acquired(), 0);
    }

    #[test]
    #[should_panic(expected = "Bulkhead limit must be greater than 0")]
    fn test_bulkhead_zero_limit() {
        let _ = BulkheadSemaphore::new(0);
    }

    #[test]
    fn test_bulkhead_guard_releases_on_panic() {
        let bulkhead = Arc::new(BulkheadSemaphore::new(2));
        let bulkhead_clone = Arc::clone(&bulkhead);
        let result = std::panic::catch_unwind(move || {
            let _guard = bulkhead_clone.try_acquire().unwrap();
            panic!("Simulated panic");
        });
        assert!(result.is_err());
        assert_eq!(bulkhead.acquired(), 0);
    }
}