use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::time::Duration;
use tokio::sync::{Notify, watch};
use tokio::time::sleep;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, warn};
/// Sentinel stored in `Inner::grace_period_nanos` meaning "no grace period".
/// All bits set, i.e. the same value as `u64::MAX`.
const NANOS_NONE: u64 = !0;
/// Cheaply cloneable handle to a shared admission gate.
///
/// Every clone points at the same `Inner` state through an `Arc`: the
/// live-permit counter, the optional capacity limit and the graceful/forced
/// shutdown signals.
#[derive(Debug, Clone)]
pub struct Gate {
    inner: Arc<Inner>,
}
/// State shared by all clones of a [`Gate`].
#[derive(Debug)]
struct Inner {
    /// Latched by `Gate::graceful_shutdown`; once set, no new permits are issued.
    graceful: NotifyOnce,
    /// Latched by `Gate::force_shutdown`.
    shutdown: NotifyOnce,
    /// Pulsed by `Permit::drop` when a slot frees up below `max_count`;
    /// `Gate::enter` waits on it while the gate is full.
    released: Notify,
    /// Number of live permits.
    count: AtomicUsize,
    /// Latched when the last permit is dropped after graceful shutdown began.
    all_done: NotifyOnce,
    /// Grace period in nanoseconds, or `NANOS_NONE` when none was given.
    grace_period_nanos: AtomicU64,
    /// Upper bound on how long `Gate::enter` waits for a free slot.
    acquire_timeout: Duration,
    /// Maximum number of live permits; `None` means unbounded.
    max_count: Option<usize>,
}
/// Reasons a [`Gate`] refuses to hand out a [`Permit`].
#[derive(Debug)]
pub enum Error {
    /// Graceful shutdown has started; no new permits are issued.
    ShuttingDown,
    /// `Gate::enter` waited for the contained duration without a slot freeing up.
    AcquireTimeout(Duration),
    /// `Gate::try_enter` found every slot taken.
    AtCapacity,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::ShuttingDown => f.write_str("gate is shutting down"),
            Error::AcquireTimeout(waited) => {
                write!(f, "timed out after {:?} waiting for a free slot", waited)
            }
            Error::AtCapacity => f.write_str("gate is at capacity"),
        }
    }
}

// Lets callers propagate gate failures with `?` into `Box<dyn Error>`-style types.
impl std::error::Error for Error {}
impl Gate {
pub fn new(
max_count: Option<usize>,
acquire_timeout: Duration,
) -> Self {
let inner = Inner {
graceful: NotifyOnce::default(),
shutdown: NotifyOnce::default(),
released: Notify::new(),
count: AtomicUsize::new(0),
all_done: NotifyOnce::default(),
grace_period_nanos: AtomicU64::new(NANOS_NONE),
acquire_timeout,
max_count,
};
Self { inner: Arc::new(inner) }
}
fn permit(&self) -> Permit {
Permit::new(self.clone())
}
pub fn grace_period(&self) -> Option<Duration> {
match self.inner.grace_period_nanos.load(Ordering::Acquire) {
NANOS_NONE => None,
nanos => Some(Duration::from_nanos(nanos)),
}
}
pub fn count(&self) -> usize {
self.inner.count.load(Ordering::SeqCst)
}
pub fn force_shutdown(&self) {
self.inner.shutdown.notify_waiters();
}
pub fn graceful_shutdown(
&self,
duration: Option<Duration>,
) {
let nanos = duration.map_or(NANOS_NONE, |d| d.as_nanos() as u64);
self.inner.grace_period_nanos.store(nanos, Ordering::Release);
self.inner.graceful.notify_waiters();
}
pub fn is_shutting_down(&self) -> bool {
self.inner.graceful.is_notified()
}
pub async fn wait_forced_shutdown(&self) {
self.inner.shutdown.notified().await;
}
pub async fn wait_graceful_shutdown(&self) {
self.inner.graceful.notified().await;
}
pub async fn enter(&self) -> Result<Permit, Error> {
let wait_timeout = self.inner.acquire_timeout;
let start = tokio::time::Instant::now();
loop {
if self.inner.graceful.is_notified() {
return Err(Error::ShuttingDown);
}
let count = self.inner.count.load(Ordering::SeqCst);
if let Some(max_count) = self.inner.max_count {
if count < max_count {
return Ok(self.permit());
}
if count == max_count {
debug!("Connection limit reached: {}/{} connections in use", count, max_count);
}
} else {
return Ok(self.permit());
}
let elapsed = start.elapsed();
if elapsed >= wait_timeout {
warn!(
"Connection acquire timeout after {:?}. Current: {}/{}",
wait_timeout,
count,
self.inner.max_count.unwrap_or(usize::MAX)
);
return Err(Error::AcquireTimeout(wait_timeout));
}
let remaining = wait_timeout - elapsed;
tokio::select! {
biased;
_ = self.inner.graceful.notified() => {
return Err(Error::ShuttingDown);
}
_ = self.inner.released.notified() => {
continue;
}
_ = sleep(remaining) => {
warn!(
"Connection acquire timeout after {:?}. Current: {}/{}",
wait_timeout,
count,
self.inner.max_count.unwrap_or(usize::MAX)
);
return Err(Error::AcquireTimeout(wait_timeout));
}
}
}
}
pub fn try_enter(&self) -> Result<Permit, Error> {
if self.inner.graceful.is_notified() {
return Err(Error::ShuttingDown);
}
if let Some(max) = self.inner.max_count {
let prev = self.inner.count.fetch_add(1, Ordering::Relaxed);
if prev >= max {
self.inner.count.fetch_sub(1, Ordering::Relaxed);
return Err(Error::AtCapacity);
}
} else {
self.inner.count.fetch_add(1, Ordering::Relaxed);
}
Ok(Permit { gate: self.clone() })
}
pub async fn wait_all_done(&self) {
if self.inner.count.load(Ordering::SeqCst) == 0 {
return;
}
let deadline = self.grace_period();
match deadline {
Some(duration) => tokio::select! {
biased;
_ = sleep(duration) => {
error!("⛔ Graceful timeout exceeded after {:?}; forcing shutdown", duration);
self.force_shutdown();
},
_ = self.inner.all_done.notified() => {
debug!("🍺 All connections finished before graceful timeout");
},
},
None => self.inner.all_done.notified().await,
}
}
}
/// Guard for one occupied slot of a [`Gate`].
///
/// Dropping it decrements the gate's live-permit count and wakes waiters
/// (see its `Drop` impl). It derives `Debug` so types that hold a permit can
/// derive `Debug` themselves.
#[derive(Debug)]
pub struct Permit {
    gate: Gate,
}
#[allow(unused)]
impl Permit {
fn new(gate: Gate) -> Self {
gate.inner.count.fetch_add(1, Ordering::SeqCst);
Self { gate }
}
pub async fn wait_graceful_shutdown(&self) {
self.gate.wait_graceful_shutdown().await
}
pub async fn wait_forced_shutdown(&self) {
self.gate.wait_forced_shutdown().await
}
pub fn is_shutting_down(&self) -> bool {
self.gate.is_shutting_down()
}
}
impl Drop for Permit {
    /// Releases this permit's slot and wakes whoever may be waiting on it.
    fn drop(&mut self) {
        let inner = &self.gate.inner;
        // `fetch_sub` yields the value *before* the decrement.
        let remaining = inner.count.fetch_sub(1, Ordering::SeqCst) - 1;

        // The last permit leaving during graceful shutdown unblocks `wait_all_done`.
        let drained_during_shutdown = remaining == 0 && inner.graceful.is_notified();
        if drained_during_shutdown {
            inner.all_done.notify_waiters();
        }

        // Bounded gates wake `enter` callers that are waiting for a free slot.
        let slot_available = matches!(inner.max_count, Some(max_count) if remaining < max_count);
        if slot_available {
            inner.released.notify_waiters();
        }
    }
}
/// Builds a [`Gate`] and starts graceful shutdown automatically once `token`
/// is cancelled.
///
/// `graceful_timeout` is passed to [`Gate::graceful_shutdown`] at that point.
/// The watcher is started with `tokio::spawn`, which requires an active Tokio
/// runtime. The spawned task keeps a clone of the gate until the token fires.
pub fn create_gate(
    token: CancellationToken,
    graceful_timeout: Option<Duration>,
    max_count: Option<usize>,
    acquire_timeout: Duration,
) -> Gate {
    let gate = Gate::new(max_count, acquire_timeout);
    tokio::spawn({
        let watcher = gate.clone();
        async move {
            token.cancelled().await;
            watcher.graceful_shutdown(graceful_timeout);
        }
    });
    gate
}
/// Default upper bound on how long `Gate::enter` waits for a free slot (100 ms).
#[inline]
pub fn default_acquire_timeout() -> Duration {
    const MILLIS: u64 = 100;
    Duration::from_millis(MILLIS)
}
/// Latching one-shot signal backed by a `watch` channel carrying a `bool`.
///
/// The flag starts `false`, and [`NotifyOnce::notify_waiters`] sets it to `true`.
/// Nothing in this type resets it. Because the state lives in the channel, a
/// waiter that arrives after the signal still observes it.
#[derive(Debug)]
struct NotifyOnce {
    /// Used to set the flag.
    tx: watch::Sender<bool>,
    /// Kept so the current value can be read, or cloned for waiting, at any time.
    rx: watch::Receiver<bool>,
}

impl Default for NotifyOnce {
    /// Creates the signal in its "not yet notified" state.
    fn default() -> Self {
        let (tx, rx) = watch::channel(false);
        NotifyOnce { tx, rx }
    }
}

impl NotifyOnce {
    /// Sets the flag to `true`, waking every task blocked in [`NotifyOnce::notified`].
    fn notify_waiters(&self) {
        let _previous = self.tx.send_replace(true);
    }

    /// Whether the flag has been set.
    fn is_notified(&self) -> bool {
        *self.rx.borrow()
    }

    /// Resolves once the flag is `true`, immediately if it already is.
    ///
    /// It also resolves if the sender is gone. That cannot happen while `self`
    /// is borrowed, because `self` owns the sender.
    async fn notified(&self) {
        let mut rx = self.rx.clone();
        while !*rx.borrow_and_update() {
            if rx.changed().await.is_err() {
                break;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{Error, Gate};
    use std::time::Duration;

    #[test]
    fn try_enter_counts_one_slot_per_permit() {
        let gate = Gate::new(Some(2), Duration::from_millis(10));
        let first = gate.try_enter().unwrap();
        assert_eq!(gate.count(), 1);
        let second = gate.try_enter().unwrap();
        assert_eq!(gate.count(), 2);
        assert!(matches!(gate.try_enter(), Err(Error::AtCapacity)));
        // A rejected attempt must not leave an extra slot counted.
        assert_eq!(gate.count(), 2);
        drop(first);
        assert_eq!(gate.count(), 1);
        drop(second);
        assert_eq!(gate.count(), 0);
    }

    #[test]
    fn unbounded_gate_never_reports_capacity() {
        let gate = Gate::new(None, Duration::from_millis(10));
        let permits: Vec<_> = (0..32).map(|_| gate.try_enter().unwrap()).collect();
        assert_eq!(gate.count(), 32);
        drop(permits);
        assert_eq!(gate.count(), 0);
    }

    #[test]
    fn graceful_shutdown_rejects_new_permits_and_records_grace_period() {
        let gate = Gate::new(Some(4), Duration::from_millis(10));
        assert!(!gate.is_shutting_down());
        assert_eq!(gate.grace_period(), None);

        let held = gate.try_enter().unwrap();
        gate.graceful_shutdown(Some(Duration::from_secs(3)));

        assert!(gate.is_shutting_down());
        assert!(held.is_shutting_down());
        assert_eq!(gate.grace_period(), Some(Duration::from_secs(3)));
        assert!(matches!(gate.try_enter(), Err(Error::ShuttingDown)));

        // Existing permits stay valid and still release their slot.
        drop(held);
        assert_eq!(gate.count(), 0);
    }

    #[test]
    fn graceful_shutdown_without_deadline_has_no_grace_period() {
        let gate = Gate::new(None, Duration::from_millis(10));
        gate.graceful_shutdown(None);
        assert!(gate.is_shutting_down());
        assert_eq!(gate.grace_period(), None);
    }
}