#![cfg(feature = "testing")]
use std::sync::Arc;
use async_trait::async_trait;
use dig_service::{
testing::TestNode, ExitReason, NodeLifecycle, RunContext, Service, ServiceError,
ShutdownReason, StartContext, StopContext,
};
/// Test double for [`NodeLifecycle`] that appends the name of every hook it
/// runs to a shared log, and can be told to fail at one chosen stage.
struct RecordingNode {
    /// Stage whose hook returns an `"inject"` error (`"pre_start"`,
    /// `"on_start"` or `"run"`); `None` lets every hook succeed.
    fail_at: Option<&'static str>,
    /// Ordered record of hook invocations; the test keeps its own `Arc`
    /// clone so it can inspect the record after the service returns.
    log: Arc<parking_lot::Mutex<Vec<&'static str>>>,
}
impl RecordingNode {
    /// Appends `hook` to the shared log. The lock guard is a temporary, so it
    /// is released before this returns and is never held across an `.await`.
    fn record_hook(&self, hook: &'static str) {
        self.log.lock().push(hook);
    }

    /// Returns an `"inject"` error when `stage` is the configured failure
    /// point, otherwise `Ok(())`.
    fn injected_failure(&self, stage: &str) -> anyhow::Result<()> {
        match self.fail_at {
            Some(target) if target == stage => Err(anyhow::anyhow!("inject")),
            _ => Ok(()),
        }
    }
}

#[async_trait]
impl NodeLifecycle for RecordingNode {
    const NAME: Option<&'static str> = Some("recording");

    async fn pre_start(&self, _ctx: &StartContext<'_>) -> anyhow::Result<()> {
        self.record_hook("pre_start");
        self.injected_failure("pre_start")
    }

    async fn on_start(&self, _ctx: &StartContext<'_>) -> anyhow::Result<()> {
        self.record_hook("on_start");
        self.injected_failure("on_start")
    }

    /// Logs entry, optionally fails, then parks until shutdown is signalled.
    /// Note the failure key is `"run"` while the logged entry is `"run_enter"`.
    async fn run(&self, ctx: RunContext) -> anyhow::Result<()> {
        self.record_hook("run_enter");
        self.injected_failure("run")?;
        ctx.shutdown.cancelled().await;
        self.record_hook("run_exit");
        Ok(())
    }

    // The stop hooks only record themselves; they never inject a failure.
    async fn on_stop(&self, _ctx: &StopContext<'_>) -> anyhow::Result<()> {
        self.record_hook("on_stop");
        Ok(())
    }

    async fn post_stop(&self, _ctx: &StopContext<'_>) -> anyhow::Result<()> {
        self.record_hook("post_stop");
        Ok(())
    }
}
/// Drives a non-failing [`RecordingNode`] through the whole lifecycle and
/// checks that every hook fires once, in order, and that the exit reason
/// reflects the user-requested shutdown.
#[tokio::test]
async fn happy_path_hook_order() {
    let log: Arc<parking_lot::Mutex<Vec<&'static str>>> =
        Arc::new(parking_lot::Mutex::new(Vec::new()));
    let node = RecordingNode {
        log: log.clone(),
        fail_at: None,
    };
    let svc = Service::<_, (), ()>::new(node, (), ());
    let handle = svc.handle();
    // Request shutdown from a separate task after a short delay, giving the
    // service time to reach `run`.
    tokio::spawn(async move {
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        handle.request_shutdown(ShutdownReason::UserRequested);
    });
    // `run` blocks until shutdown is signalled; bound the wait so a broken
    // shutdown path fails this test instead of hanging the whole suite.
    let status = tokio::time::timeout(std::time::Duration::from_secs(5), svc.start())
        .await
        .expect("service did not stop within 5s")
        .expect("start");
    let log = log.lock().clone();
    assert_eq!(
        log,
        vec![
            "pre_start",
            "on_start",
            "run_enter",
            "run_exit",
            "on_stop",
            "post_stop"
        ]
    );
    assert!(matches!(
        status.reason,
        ExitReason::RequestedShutdown(ShutdownReason::UserRequested)
    ));
}
/// A failing `pre_start` must abort startup with `PreStartFailed` before any
/// later hook gets a chance to run.
#[tokio::test]
async fn pre_start_failure_short_circuits() {
    let log = Arc::new(parking_lot::Mutex::new(Vec::new()));
    let svc = Service::<_, (), ()>::new(
        RecordingNode {
            fail_at: Some("pre_start"),
            log: Arc::clone(&log),
        },
        (),
        (),
    );
    let err = svc.start().await.unwrap_err();
    assert!(matches!(err, ServiceError::PreStartFailed(_)));
    // Only the failing hook itself may have been recorded.
    assert_eq!(*log.lock(), ["pre_start"]);
}
/// A failing `run` must surface as `RunFailed` while still executing the
/// shutdown hooks, with `on_stop` ahead of `post_stop` as on the happy path.
#[tokio::test]
async fn run_failure_still_runs_cleanup() {
    let log = Arc::new(parking_lot::Mutex::new(Vec::new()));
    let node = RecordingNode {
        log: log.clone(),
        fail_at: Some("run"),
    };
    let err = Service::<_, (), ()>::new(node, (), ())
        .start()
        .await
        .unwrap_err();
    assert!(matches!(err, ServiceError::RunFailed(_)));
    let log = log.lock().clone();
    // `run` bails right after logging "run_enter", so "run_exit" must be absent.
    assert!(!log.contains(&"run_exit"), "run_exit after failed run: {log:?}");
    // Presence alone is not enough: cleanup hooks must also run in order.
    let position = |hook: &str| log.iter().position(|entry| *entry == hook);
    let on_stop =
        position("on_stop").unwrap_or_else(|| panic!("on_stop missing: {log:?}"));
    let post_stop =
        position("post_stop").unwrap_or_else(|| panic!("post_stop missing: {log:?}"));
    assert!(on_stop < post_stop, "on_stop must precede post_stop: {log:?}");
}
/// The bundled `TestNode` should stop gracefully once a user shutdown is
/// requested.
#[tokio::test]
async fn test_node_graceful_exit() {
    let svc = Service::<_, (), ()>::new(TestNode::default(), (), ());
    let handle = svc.handle();
    tokio::spawn(async move {
        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        handle.request_shutdown(ShutdownReason::UserRequested);
    });
    // Bound the wait so a lost shutdown request fails the test rather than
    // hanging the suite.
    let status = tokio::time::timeout(std::time::Duration::from_secs(5), svc.start())
        .await
        .expect("service did not stop within 5s")
        .expect("start");
    assert!(status.is_graceful());
}