use crate::AppSpindownToken;
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use parking_lot::Mutex;
use scopeguard::defer;
use std::error::Error;
use std::fmt::{Display, Formatter};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::select;
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn};
/// Collects the workloads that must finish before the application is considered spun down.
pub(crate) struct SpindownRegistry {
    /// Overall time budget that `spun_down` grants to all registered workloads together.
    timeout: Duration,
    /// Workloads registered but not yet taken by a spindown round.
    registry: Mutex<Vec<SpindownWorkload>>,
}
impl SpindownRegistry {
    /// Creates an empty registry whose spindown waits at most `timeout` for its workloads.
    pub(crate) fn new(timeout: Duration) -> Self {
        let registry = Mutex::new(Vec::new());
        Self { timeout, registry }
    }

    /// Registers a workload called `name` and returns the token it uses to signal completion.
    ///
    /// `name` only identifies the workload in spindown log output.
    pub(crate) fn register(&self, name: &str) -> AppSpindownToken {
        let workload = SpindownWorkload::new(name);
        let handed_out = workload.token();
        self.registry.lock().push(workload);
        handed_out
    }
}
impl SpindownRegistry {
    /// Waits until every registered workload has punched out, or until the configured timeout
    /// elapses.
    ///
    /// The registry is drained in rounds: each round takes every workload registered so far and
    /// waits for all of them. Workloads registered while a round is in progress are picked up by
    /// the next round, and the spindown only finishes once a drain finds the registry empty. A
    /// single timer covers all rounds together.
    ///
    /// Returns the total number of workloads that completed.
    ///
    /// # Errors
    ///
    /// Returns [`SpindownTimeout`] if the timeout elapses while a round is still waiting. It
    /// carries how many drained workloads completed and how many were still outstanding.
    pub(crate) async fn spun_down(&self) -> Result<usize, SpindownTimeout> {
        info!("Spindown initiated");
        // One timer task for the whole spindown. tokio's `notify_one` stores a permit when no
        // task is waiting yet, so a timeout that fires between two rounds is still observed by
        // the next round.
        let notify_in = Arc::new(Notify::new());
        let notify_out = Arc::clone(&notify_in);
        let timeout = self.timeout;
        let timer = tokio::spawn(async move {
            tokio::time::sleep(timeout).await;
            notify_in.notify_one();
        });
        // Abort the timer however this function exits, including when this future is dropped.
        defer! { timer.abort() }
        let mut count = 0usize;
        loop {
            // Take the pending workloads and release the lock before awaiting anything.
            let workloads = {
                let mut registry = self.registry.lock();
                std::mem::take(&mut *registry)
            };
            if workloads.is_empty() {
                info!("Spindown completed");
                return Ok(count);
            }
            info!(
                "Waiting for {} registered workload(s) to complete",
                workloads.len(),
            );
            count += workloads.len();
            if let Err(error) = Self::spin_down_once(workloads, &notify_out).await {
                // `count` already includes this round, so the subtraction cannot underflow.
                return Err(SpindownTimeout {
                    spun_down: count - error.timed_out,
                    timed_out: error.timed_out,
                });
            }
        }
    }

    /// Waits for one batch of workloads to punch out, racing them against `timeout`.
    ///
    /// # Errors
    ///
    /// Returns [`SpindownTimeout`], describing this batch only, if `timeout` is notified before
    /// every workload in the batch has completed.
    async fn spin_down_once(
        workloads: Vec<SpindownWorkload>,
        timeout: &Notify,
    ) -> Result<(), SpindownTimeout> {
        let count = workloads.len();
        let mut futures = workloads
            .into_iter()
            .map(SpindownWorkloadFuture::from)
            .collect::<FuturesUnordered<_>>();
        loop {
            let state = select! {
                // Poll the timeout first so that, once it has fired, it wins even if a workload
                // also completes in the same poll.
                biased;
                _ = timeout.notified() => Self::receive_timeout(&futures),
                result = futures.next() => Self::receive_future(result, &futures),
            };
            match state {
                SpindownState::Ongoing => {}
                SpindownState::Completed => return Ok(()),
                SpindownState::TimedOut => {
                    // Completed futures are removed from the set, so whatever is left is
                    // exactly what timed out.
                    let timed_out = futures.len();
                    return Err(SpindownTimeout {
                        spun_down: count - timed_out,
                        timed_out,
                    });
                }
            }
        }
    }

    /// Logs every workload still pending in `futures` and reports the round as timed out.
    fn receive_timeout(futures: &FuturesUnordered<SpindownWorkloadFuture>) -> SpindownState {
        for future in futures {
            error!(
                workload = future.name.as_ref(),
                "Did not complete in time during spindown",
            );
        }
        warn!("Some workloads did not complete gracefully");
        SpindownState::TimedOut
    }

    /// Logs the outcome of one poll of the workload set and reports whether the round is done.
    ///
    /// `optional_outcome` is `None` only if the set was polled while already empty.
    /// `spun_down` never starts a round with an empty batch, so that case is logged as an alert.
    fn receive_future(
        optional_outcome: Option<Arc<str>>,
        futures: &FuturesUnordered<SpindownWorkloadFuture>,
    ) -> SpindownState {
        match optional_outcome {
            Some(workload) => {
                info!(workload = workload.as_ref(), "Completed gracefully");
            }
            None => {
                error!(
                    alert = true,
                    "Polled spindown futures while they are all already completed",
                );
            }
        }
        if futures.is_empty() {
            info!("All workloads completed gracefully");
            return SpindownState::Completed;
        }
        SpindownState::Ongoing
    }
}
/// Result of handling one event while waiting for a batch of workloads.
enum SpindownState {
    /// At least one workload in the batch is still pending.
    Ongoing,
    /// Every workload in the batch has punched out.
    Completed,
    /// The timeout fired before the batch finished.
    TimedOut,
}
/// Error returned when the spindown timeout elapses before every workload has punched out.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct SpindownTimeout {
    /// Number of drained workloads that completed before the timeout.
    spun_down: usize,
    /// Number of drained workloads still outstanding when the timeout elapsed.
    timed_out: usize,
}

impl Display for SpindownTimeout {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // `write_str` does not interpolate; `write!` is needed for the counts to appear.
        write!(
            f,
            "failed to fully spin down all workloads within the timeout: {} completed, {} timed out",
            self.spun_down, self.timed_out,
        )
    }
}

impl Error for SpindownTimeout {}
/// A registered workload: a name for logging plus the cancellation token that signals completion.
struct SpindownWorkload {
    name: Arc<str>,
    token: CancellationToken,
}

impl SpindownWorkload {
    /// Creates a workload named `name` with a freshly created cancellation token.
    fn new(name: &str) -> Self {
        let token = CancellationToken::new();
        let name: Arc<str> = name.into();
        Self { name, token }
    }

    /// Returns an [`AppSpindownToken`] wrapping a clone of this workload's cancellation token.
    fn token(&self) -> AppSpindownToken {
        let shared = self.token.clone();
        AppSpindownToken::new(shared)
    }
}
impl From<SpindownWorkload> for SpindownWorkloadFuture {
fn from(workload: SpindownWorkload) -> Self {
let token_future = Box::pin(async move { workload.token.cancelled().await });
SpindownWorkloadFuture {
name: workload.name,
token_future,
}
}
}
/// A future that resolves to a workload's name once that workload has punched out.
///
/// The boxed future is `Send`. Without that bound, the future returned by
/// `SpindownRegistry::spun_down`, which holds these across `.await` points, would be `!Send`
/// and could not be driven from a spawned tokio task.
struct SpindownWorkloadFuture {
    /// Name reported in logs when the workload completes or times out.
    name: Arc<str>,
    /// Completes when the workload's cancellation token is cancelled.
    token_future: Pin<Box<dyn Future<Output = ()> + Send>>,
}
impl Future for SpindownWorkloadFuture {
    type Output = Arc<str>;

    /// Polls the token future and yields the workload's name once it has completed.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let finished = self.token_future.as_mut().poll(cx);
        finished.map(|()| Arc::clone(&self.name))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::time::Duration;
    use tokio::time::Instant;

    /// Builds an empty registry with the given spindown timeout.
    fn make_registry(timeout: Duration) -> SpindownRegistry {
        SpindownRegistry::new(timeout)
    }

    /// Awaits `future` and returns its output together with how long it took.
    async fn timed<F: Future>(future: F) -> (F::Output, Duration) {
        let started = Instant::now();
        let output = future.await;
        (output, started.elapsed())
    }

    #[tokio::test]
    async fn no_workloads() {
        let registry = make_registry(Duration::from_secs(5));
        let (result, elapsed) = timed(registry.spun_down()).await;
        assert_eq!(result.unwrap(), 0);
        assert!(
            elapsed < Duration::from_millis(50),
            "spun_down() should return immediately when no workloads are registered",
        );
    }

    #[tokio::test]
    async fn all_workloads_complete() {
        let registry = make_registry(Duration::from_secs(5));
        let first = registry.register("workload1");
        let second = registry.register("workload2");
        first.punch_out();
        second.punch_out();
        let (result, elapsed) = timed(registry.spun_down()).await;
        assert_eq!(result.unwrap(), 2);
        assert!(
            elapsed < Duration::from_millis(50),
            "spun_down() should complete quickly when all workloads complete",
        );
    }

    #[tokio::test]
    async fn timeout() {
        let registry = make_registry(Duration::from_millis(100));
        let _token = registry.register("workload_timeout");
        let (result, elapsed) = timed(registry.spun_down()).await;
        let expected = SpindownTimeout {
            spun_down: 0,
            timed_out: 1,
        };
        assert_eq!(result.unwrap_err(), expected);
        assert!(
            elapsed >= Duration::from_millis(100),
            "spun_down() should wait until timeout when workload doesn't complete",
        );
    }

    #[tokio::test]
    async fn token_drop_punch_out() {
        let registry = make_registry(Duration::from_secs(5));
        drop(registry.register("dropped_workload"));
        let (result, elapsed) = timed(registry.spun_down()).await;
        assert_eq!(result.unwrap(), 1);
        assert!(
            elapsed < Duration::from_millis(50),
            "spun_down() should complete quickly when the token is dropped",
        );
    }
}