use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
/// Error value describing a start request made while the runtime is already running.
///
/// NOTE(review): `DomainRuntime::start` in the visible code declares this as its error
/// type but reports a duplicate start as `Ok(false)`; confirm whether other callers
/// construct it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct AlreadyRunning;

impl std::fmt::Display for AlreadyRunning {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fixed message; width/fill flags from the caller are not applied.
        write!(f, "runtime is already running")
    }
}

impl std::error::Error for AlreadyRunning {}
/// Point-in-time snapshot of a runtime's run state, serializable via serde.
#[derive(Clone, Copy, Debug, serde::Serialize)]
pub struct RuntimeStatus {
    /// `true` while a background loop is considered active.
    pub running: bool,
}
/// Mutable run bookkeeping shared by every [`DomainRuntime`] clone and its watcher tasks.
struct State {
    /// Whether a background loop is currently considered active.
    running: bool,
    /// Cancellation handle of the active loop; `Some` exactly while `running` is `true`.
    cancel: Option<CancellationToken>,
    /// Bumped on every successful `start`, so a watcher can tell whether the state it is
    /// about to reset still belongs to the run it was spawned for.
    generation: u64,
}

/// Shares a core value and supervises at most one background loop at a time.
///
/// Clones are cheap handles: they share the same core and the same run state.
pub struct DomainRuntime<C> {
    core: Arc<C>,
    state: Arc<Mutex<State>>,
}

// Implemented by hand so that cloning a handle does not require `C: Clone`;
// only the two `Arc`s are cloned.
impl<C> Clone for DomainRuntime<C> {
    fn clone(&self) -> Self {
        Self {
            core: Arc::clone(&self.core),
            state: Arc::clone(&self.state),
        }
    }
}

impl<C> DomainRuntime<C> {
    /// Creates an idle runtime around `core`.
    pub fn new(core: Arc<C>) -> Self {
        Self {
            core,
            state: Arc::new(Mutex::new(State {
                running: false,
                cancel: None,
                generation: 0,
            })),
        }
    }

    /// Returns another handle to the shared core.
    pub fn core(&self) -> Arc<C> {
        Arc::clone(&self.core)
    }

    /// Starts the background loop built by `mk`, unless one is already running.
    ///
    /// `mk` receives a fresh [`CancellationToken`] that [`stop`](Self::stop) will cancel,
    /// and must return the [`JoinHandle`] of the loop it spawned. `mk` is invoked while
    /// the internal lock is held, so two concurrent `start` calls cannot both spawn a loop.
    /// A watcher task awaits the handle and marks the runtime idle once that loop ends,
    /// provided no newer run has been started in the meantime.
    ///
    /// Returns `Ok(true)` if a loop was started and `Ok(false)` if one was already running.
    ///
    /// # Errors
    ///
    /// This implementation never returns `Err`; a duplicate start is reported as `Ok(false)`.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because the watcher is spawned with
    /// [`tokio::spawn`].
    pub async fn start<F>(&self, mk: F) -> Result<bool, AlreadyRunning>
    where
        F: FnOnce(CancellationToken) -> JoinHandle<()>,
    {
        let mut state = self.state.lock().await;
        if state.running {
            return Ok(false);
        }
        let token = CancellationToken::new();
        let handle = mk(token.clone());
        state.generation = state.generation.wrapping_add(1);
        let generation = state.generation;
        state.cancel = Some(token);
        state.running = true;
        drop(state);

        let shared = Arc::clone(&self.state);
        tokio::spawn(async move {
            // The loop's outcome (including a panic surfaced as `JoinError`) is
            // deliberately ignored: the watcher only tracks liveness.
            let _ = handle.await;
            let mut guard = shared.lock().await;
            // If `stop` and then a new `start` ran while this loop was winding down, the
            // state now describes the newer run; resetting it would hide that run from
            // `status` and make it impossible to `stop`.
            if guard.generation == generation {
                guard.running = false;
                guard.cancel = None;
            }
        });
        Ok(true)
    }

    /// Signals the running loop to cancel and marks the runtime idle immediately.
    ///
    /// Returns `true` if a loop was running and its token has been cancelled, `false`
    /// if nothing was running. The loop's task may still be finishing when this returns.
    pub async fn stop(&self) -> bool {
        let mut state = self.state.lock().await;
        if let Some(token) = state.cancel.take() {
            token.cancel();
            state.running = false;
            true
        } else {
            false
        }
    }

    /// Returns a snapshot of whether a loop is currently considered running.
    pub async fn status(&self) -> RuntimeStatus {
        let state = self.state.lock().await;
        RuntimeStatus {
            running: state.running,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder core value; the runtime never inspects it.
    struct Core;

    /// Builds a fresh, idle runtime around a placeholder core.
    fn idle_runtime() -> DomainRuntime<Core> {
        DomainRuntime::new(Arc::new(Core))
    }

    /// Spawns a loop that does nothing except wait for its cancellation signal.
    fn park_until_cancelled(token: CancellationToken) -> JoinHandle<()> {
        tokio::spawn(async move {
            token.cancelled().await;
        })
    }

    #[tokio::test]
    async fn start_then_status_running() {
        let rt = idle_runtime();
        assert!(!rt.status().await.running);
        assert_eq!(rt.start(park_until_cancelled).await, Ok(true));
        assert!(rt.status().await.running);
    }

    #[tokio::test]
    async fn double_start_is_noop() {
        let rt = idle_runtime();
        assert_eq!(rt.start(park_until_cancelled).await, Ok(true));
        assert_eq!(rt.start(park_until_cancelled).await, Ok(false));
        assert!(rt.status().await.running);
    }

    #[tokio::test]
    async fn stop_clears_running() {
        let rt = idle_runtime();
        rt.start(park_until_cancelled).await.unwrap();
        let first_stop = rt.stop().await;
        assert!(first_stop);
        assert!(!rt.status().await.running);
        let second_stop = rt.stop().await;
        assert!(!second_stop);
    }

    #[tokio::test]
    async fn watcher_flips_running_when_loop_finishes() {
        let rt = idle_runtime();
        rt.start(|_| tokio::spawn(async {})).await.unwrap();
        // Give the watcher up to 50 scheduler turns to observe the finished loop.
        let mut turns_left = 50;
        while turns_left > 0 && rt.status().await.running {
            tokio::task::yield_now().await;
            turns_left -= 1;
        }
        assert!(!rt.status().await.running);
    }

    #[test]
    fn already_running_display() {
        let rendered = AlreadyRunning.to_string();
        assert_eq!(rendered, "runtime is already running");
    }
}