use crate::error::ClusterError;
use crate::sharding::Sharding;
use futures::future::BoxFuture;
use std::future::Future;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
/// Context passed to a singleton's run closure.
///
/// It wraps the singleton's [`CancellationToken`] together with a shared
/// `managed` flag. Calling [`SingletonContext::cancellation`] sets that flag
/// to record that the singleton wants to handle cancellation itself.
#[derive(Debug)]
pub struct SingletonContext {
    /// Cancellation token for this singleton run.
    inner: CancellationToken,
    /// Shared flag, set to `true` once `cancellation()` has been called.
    pub(crate) managed: Arc<AtomicBool>,
}
impl SingletonContext {
pub(crate) fn new(cancellation: CancellationToken, managed: Arc<AtomicBool>) -> Self {
Self {
inner: cancellation,
managed,
}
}
pub fn cancellation(&self) -> CancellationToken {
self.managed.store(true, Ordering::Release);
self.inner.clone()
}
pub fn is_cancelled(&self) -> bool {
self.inner.is_cancelled()
}
}
/// Registers `run` on `sharding` as a singleton called `name`.
///
/// The strongly-typed closure is adapted into one that returns a pinned,
/// boxed future, which is what `Sharding::register_singleton` receives.
/// `None` is passed as that method's second argument.
///
/// # Errors
///
/// Returns any error reported by `Sharding::register_singleton`.
pub async fn register_singleton<F, Fut>(
    sharding: &dyn Sharding,
    name: &str,
    run: F,
) -> Result<(), ClusterError>
where
    F: Fn(SingletonContext) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<(), ClusterError>> + Send + 'static,
{
    // Erase the concrete `Fut` type by boxing every future `run` produces.
    let erased = Arc::new(
        move |ctx: SingletonContext| -> BoxFuture<'static, Result<(), ClusterError>> {
            Box::pin(run(ctx))
        },
    );
    sharding.register_singleton(name, None, erased).await
}
/// Pending singleton registration, created by [`singleton`] or
/// [`SingletonBuilder::new`].
///
/// Nothing is registered until [`SingletonBuilder::register`] is awaited.
/// The type is `#[must_use]` so that a builder dropped by mistake is reported.
#[must_use = "a SingletonBuilder does nothing until `.register(..)` is awaited"]
pub struct SingletonBuilder<F> {
    /// Name the singleton is registered under.
    name: String,
    /// Closure executed for the singleton.
    run: F,
}
impl<F, Fut> SingletonBuilder<F>
where
    F: Fn(SingletonContext) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<(), ClusterError>> + Send + 'static,
{
    /// Starts a builder for a singleton called `name` that executes `run`.
    pub fn new(name: impl Into<String>, run: F) -> Self {
        let name: String = name.into();
        Self { name, run }
    }

    /// Consumes the builder and registers the singleton on `sharding`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`register_singleton`].
    pub async fn register(self, sharding: &dyn Sharding) -> Result<(), ClusterError> {
        let Self { name, run } = self;
        register_singleton(sharding, name.as_str(), run).await
    }
}
pub fn singleton<F, Fut>(name: impl Into<String>, run: F) -> SingletonBuilder<F>
where
F: Fn(SingletonContext) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<(), ClusterError>> + Send + 'static,
{
SingletonBuilder::new(name, run)
}
/// Type-erased singleton run function.
///
/// It takes a [`SingletonContext`] and returns a boxed `'static` future that
/// resolves to `Result<(), ClusterError>`. This is the form accepted by
/// [`register_singletons`].
pub type SingletonRun = Arc<
    dyn Fn(SingletonContext) -> BoxFuture<'static, Result<(), ClusterError>> + Sync + Send,
>;
/// Registers every `(name, run)` pair on `sharding`, one after another, in
/// the order given.
///
/// # Errors
///
/// Stops at the first registration error and returns it. Singletons
/// registered before the failure stay registered, because no rollback is
/// attempted here.
pub async fn register_singletons(
    sharding: &dyn Sharding,
    singletons: Vec<(String, SingletonRun)>,
) -> Result<(), ClusterError> {
    for entry in singletons {
        let (name, run) = entry;
        sharding.register_singleton(name.as_str(), None, run).await?;
    }
    Ok(())
}
// Tests for the singleton registration helpers. Each test builds its own
// `ShardingImpl` via `test_sharding` and uses short `tokio::time::sleep` calls
// to give singletons time to run, so the assertions are timing-sensitive.
#[cfg(test)]
mod tests {
use super::*;
use crate::config::ShardingConfig;
use crate::metrics::ClusterMetrics;
use crate::sharding_impl::ShardingImpl;
use crate::storage::noop_runners::NoopRunners;
use std::sync::Arc;
/// Builds a `ShardingImpl` with the default `ShardingConfig`, unregistered
/// metrics and `NoopRunners`, then calls `acquire_all_shards` on it.
// NOTE(review): acquiring all shards is presumably what lets singletons get
// scheduled on this node during tests — confirm against ShardingImpl.
async fn test_sharding() -> Arc<ShardingImpl> {
let config = Arc::new(ShardingConfig::default());
let metrics = Arc::new(ClusterMetrics::unregistered());
let s =
ShardingImpl::new(config, Arc::new(NoopRunners), None, None, None, metrics).unwrap();
s.acquire_all_shards().await;
s
}
/// Registering through the free `register_singleton` function should
/// eventually execute the closure.
#[tokio::test]
async fn register_singleton_via_helper() {
let sharding = test_sharding().await;
let executed = Arc::new(std::sync::atomic::AtomicBool::new(false));
let executed_clone = executed.clone();
register_singleton(
sharding.as_ref(),
"test-singleton",
move |_ctx: SingletonContext| {
let e = executed_clone.clone();
async move {
e.store(true, std::sync::atomic::Ordering::SeqCst);
Ok(())
}
},
)
.await
.unwrap();
// Give the singleton time to run before checking the flag.
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
assert!(executed.load(std::sync::atomic::Ordering::SeqCst));
}
/// The `singleton(..).register(..)` builder path should also execute the
/// closure.
#[tokio::test]
async fn singleton_builder_api() {
let sharding = test_sharding().await;
let executed = Arc::new(std::sync::atomic::AtomicBool::new(false));
let executed_clone = executed.clone();
singleton("builder-singleton", move |_ctx: SingletonContext| {
let e = executed_clone.clone();
async move {
e.store(true, std::sync::atomic::Ordering::SeqCst);
Ok(())
}
})
.register(sharding.as_ref())
.await
.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
assert!(executed.load(std::sync::atomic::Ordering::SeqCst));
}
/// Three singletons registered in one `register_singletons` call should
/// together increment the shared counter to 3.
#[tokio::test]
async fn register_multiple_singletons() {
let sharding = test_sharding().await;
let count = Arc::new(std::sync::atomic::AtomicU32::new(0));
let mut singletons: Vec<(String, SingletonRun)> = Vec::new();
for i in 0..3 {
let c = count.clone();
singletons.push((
format!("singleton-{i}"),
Arc::new(
move |_ctx| -> BoxFuture<'static, Result<(), ClusterError>> {
let c = c.clone();
Box::pin(async move {
c.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
Ok(())
})
},
),
));
}
register_singletons(sharding.as_ref(), singletons)
.await
.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 3);
}
/// The context handed to a freshly started singleton should not already be
/// cancelled. Note: this only calls `is_cancelled`, which does not set the
/// context's `managed` flag.
#[tokio::test]
async fn singleton_receives_cancellation_token() {
let sharding = test_sharding().await;
let token_received = Arc::new(std::sync::atomic::AtomicBool::new(false));
let token_received_clone = token_received.clone();
register_singleton(
sharding.as_ref(),
"cancellation-test",
move |ctx: SingletonContext| {
let t = token_received_clone.clone();
async move {
if !ctx.is_cancelled() {
t.store(true, std::sync::atomic::Ordering::SeqCst);
}
Ok(())
}
},
)
.await
.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
assert!(token_received.load(std::sync::atomic::Ordering::SeqCst));
}
/// A singleton that never calls `ctx.cancellation()` (so `managed` stays
/// false) loops forever. `shutdown` is still expected to return within
/// 500ms.
#[tokio::test]
async fn singleton_not_managing_cancellation_is_force_cancelled() {
let sharding = test_sharding().await;
register_singleton(
sharding.as_ref(),
"unmanaged-singleton",
move |_ctx: SingletonContext| async move {
loop {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
// Unreachable after the infinite loop, but it pins the async block's
// output to a `Result`, as `register_singleton`'s `Fut` bound requires.
#[allow(unreachable_code)]
Ok(())
},
)
.await
.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let start = std::time::Instant::now();
sharding.shutdown().await.unwrap();
let elapsed = start.elapsed();
assert!(
elapsed < std::time::Duration::from_millis(500),
"shutdown took too long ({:?}), singleton may not have been force-cancelled",
elapsed
);
}
/// A singleton that calls `ctx.cancellation()` and does 50ms of cleanup
/// after the token fires should have finished that cleanup by the time
/// `shutdown` returns.
#[tokio::test]
async fn singleton_managing_cancellation_waits_for_graceful_shutdown() {
let sharding = test_sharding().await;
let cleanup_ran = Arc::new(std::sync::atomic::AtomicBool::new(false));
let cleanup_ran_clone = cleanup_ran.clone();
register_singleton(
sharding.as_ref(),
"managed-singleton",
move |ctx: SingletonContext| {
let cleanup = cleanup_ran_clone.clone();
async move {
// Requesting the token sets the context's `managed` flag.
let cancel = ctx.cancellation();
loop {
tokio::select! {
_ = cancel.cancelled() => {
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
cleanup.store(true, std::sync::atomic::Ordering::SeqCst);
break;
}
_ = tokio::time::sleep(std::time::Duration::from_secs(1)) => {
}
}
}
Ok(())
}
},
)
.await
.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
sharding.shutdown().await.unwrap();
assert!(
cleanup_ran.load(std::sync::atomic::Ordering::SeqCst),
"cleanup should run when singleton manages cancellation"
);
}
}