use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use async_trait::async_trait;
use tokio_util::sync::CancellationToken;
use super::service::{JanusService, RestartPolicy};
use crate::state::JanusState;
/// Boxed, type-erased module entry point.
///
/// The closure receives the shared [`JanusState`] and returns a pinned, boxed,
/// `Send` future resolving to the module's final `crate::Result<()>`. The
/// closure itself is `Send + Sync`, so it can be stored inside a
/// [`ModuleAdapter`] and called (via `&self`) on every run.
pub type ModuleStartFn = Box<
dyn Fn(Arc<JanusState>) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>>
+ Send
+ Sync,
>;
/// Adapts a plain async module entry point ([`ModuleStartFn`]) to the
/// [`JanusService`] supervisor interface, adding shutdown bridging and
/// module-health registration around each run.
pub struct ModuleAdapter {
/// Service name; used for `name()`, log fields, health entries, and as the
/// prefix of errors returned from `run`.
name: String,
/// Shared application state; a clone is handed to `start_fn` on every run.
state: Arc<JanusState>,
/// Module entry point invoked once per `run` call.
start_fn: ModuleStartFn,
/// Restart policy reported to the supervisor via `restart_policy()`.
policy: RestartPolicy,
}
impl ModuleAdapter {
pub fn new<F>(name: &str, state: Arc<JanusState>, start_fn: F, policy: RestartPolicy) -> Self
where
F: Fn(Arc<JanusState>) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>>
+ Send
+ Sync
+ 'static,
{
Self {
name: name.to_string(),
state,
start_fn: Box::new(start_fn),
policy,
}
}
pub fn on_failure<F>(name: &str, state: Arc<JanusState>, start_fn: F) -> Self
where
F: Fn(Arc<JanusState>) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>>
+ Send
+ Sync
+ 'static,
{
Self::new(name, state, start_fn, RestartPolicy::OnFailure)
}
pub fn one_shot<F>(name: &str, state: Arc<JanusState>, start_fn: F) -> Self
where
F: Fn(Arc<JanusState>) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>>
+ Send
+ Sync
+ 'static,
{
Self::new(name, state, start_fn, RestartPolicy::Never)
}
pub fn always_restart<F>(name: &str, state: Arc<JanusState>, start_fn: F) -> Self
where
F: Fn(Arc<JanusState>) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>>
+ Send
+ Sync
+ 'static,
{
Self::new(name, state, start_fn, RestartPolicy::Always)
}
}
#[async_trait]
impl JanusService for ModuleAdapter {
fn name(&self) -> &str {
&self.name
}
fn restart_policy(&self) -> RestartPolicy {
self.policy
}
#[tracing::instrument(skip(self, cancel), fields(module = %self.name, policy = %self.policy))]
async fn run(&self, cancel: CancellationToken) -> anyhow::Result<()> {
if self.state.is_shutdown_requested() {
tracing::warn!(
module = %self.name,
"JanusState.shutdown_requested is already true at module start — \
the module may exit immediately. This indicates a stale shutdown \
flag from a previous lifecycle or a concurrent shutdown in progress."
);
}
let bridge_state = self.state.clone();
let bridge_cancel = cancel.clone();
let bridge_handle = tokio::spawn(async move {
bridge_cancel.cancelled().await;
tracing::info!(
"ModuleAdapter shutdown bridge: cancellation received, requesting state shutdown"
);
bridge_state.request_shutdown();
});
self.state
.register_module_health(&self.name, true, Some("starting".to_string()))
.await;
let state_clone = self.state.clone();
let module_result = (self.start_fn)(state_clone).await;
bridge_handle.abort();
match &module_result {
Ok(()) => {
self.state
.register_module_health(&self.name, true, Some("stopped".to_string()))
.await;
}
Err(e) => {
self.state
.register_module_health(&self.name, false, Some(format!("error: {e}")))
.await;
}
}
module_result.map_err(|e| anyhow::anyhow!("{}: {}", self.name, e))
}
}
impl std::fmt::Debug for ModuleAdapter {
    /// Shows the adapter's name and restart policy. The shared state and the
    /// boxed start closure are not printed, so the output ends with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ModuleAdapter");
        out.field("name", &self.name);
        out.field("policy", &self.policy);
        out.finish_non_exhaustive()
    }
}
/// [`JanusService`] wrapper for the API module: a [`ModuleAdapter`] whose
/// name is `"api"` and whose restart policy is [`RestartPolicy::Always`].
pub struct ApiModuleAdapter {
/// The underlying adapter; every service call is delegated to it.
inner: ModuleAdapter,
}
impl ApiModuleAdapter {
    /// Wraps `start_fn` as the `"api"` module, restarted under
    /// [`RestartPolicy::Always`].
    pub fn new<F>(state: Arc<JanusState>, start_fn: F) -> Self
    where
        F: Fn(Arc<JanusState>) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>>
            + Send
            + Sync
            + 'static,
    {
        let inner = ModuleAdapter::always_restart("api", state, start_fn);
        Self { inner }
    }
}
/// Every `JanusService` call is forwarded to the wrapped [`ModuleAdapter`].
#[async_trait]
impl JanusService for ApiModuleAdapter {
    fn name(&self) -> &str {
        self.inner.name()
    }

    /// Reported from the inner adapter, which `ApiModuleAdapter::new` always
    /// builds with [`RestartPolicy::Always`].
    fn restart_policy(&self) -> RestartPolicy {
        self.inner.restart_policy()
    }

    async fn run(&self, cancel: CancellationToken) -> anyhow::Result<()> {
        self.inner.run(cancel).await
    }
}
impl std::fmt::Debug for ApiModuleAdapter {
    /// Shows the wrapped adapter using its own `Debug` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ApiModuleAdapter");
        out.field("inner", &self.inner);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::Duration;

    /// Fresh `JanusState` built from the default configuration.
    async fn test_state() -> Arc<JanusState> {
        let state = JanusState::new(crate::Config::default())
            .await
            .unwrap();
        Arc::new(state)
    }

    /// Runs `adapter` through a `Box<dyn JanusService>` with a token that is
    /// never cancelled, returning whatever `run` returns.
    async fn run_boxed(adapter: impl JanusService + 'static) -> anyhow::Result<()> {
        let svc: Box<dyn JanusService> = Box::new(adapter);
        svc.run(CancellationToken::new()).await
    }

    #[tokio::test]
    async fn test_module_adapter_clean_exit() {
        let state = test_state().await;
        let calls = Arc::new(AtomicU64::new(0));
        let calls_in_module = Arc::clone(&calls);
        let adapter = ModuleAdapter::on_failure("test-clean", state.clone(), move |_s| {
            let calls = Arc::clone(&calls_in_module);
            Box::pin(async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
        });

        let outcome = run_boxed(adapter).await;
        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_module_adapter_error_propagation() {
        let state = test_state().await;
        let adapter = ModuleAdapter::on_failure("test-fail", state.clone(), |_s| {
            Box::pin(async move { Err(crate::Error::Config("boom".into())) })
        });

        let outcome = run_boxed(adapter).await;
        assert!(outcome.is_err());
        let err_msg = outcome.unwrap_err().to_string();
        // The adapter must prefix the module's error with the service name.
        for (needle, what) in [("test-fail", "service name"), ("boom", "cause")] {
            assert!(err_msg.contains(needle), "error should contain {what}: {err_msg}");
        }
    }

    #[tokio::test]
    async fn test_module_adapter_cancellation_bridge() {
        let state = test_state().await;
        // The module polls the shutdown flag and exits cleanly once it flips.
        let adapter = ModuleAdapter::on_failure("test-cancel", state.clone(), |s| {
            Box::pin(async move {
                while !s.is_shutdown_requested() {
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }
                Ok(())
            })
        });

        let cancel = CancellationToken::new();
        let trigger = cancel.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            trigger.cancel();
        });

        let svc: Box<dyn JanusService> = Box::new(adapter);
        let outcome = svc.run(cancel).await;
        assert!(outcome.is_ok());
        assert!(state.is_shutdown_requested());
    }

    #[tokio::test]
    async fn test_module_adapter_health_registration() {
        let state = test_state().await;
        let adapter = ModuleAdapter::on_failure("health-test", state.clone(), |_s| {
            Box::pin(async move { Ok(()) })
        });
        run_boxed(adapter).await.unwrap();

        let health = state.get_module_health().await;
        let entry = health
            .iter()
            .find(|h| h.name == "health-test")
            .expect("module should have registered health");
        assert!(entry.healthy);
    }

    #[tokio::test]
    async fn test_module_adapter_health_on_error() {
        let state = test_state().await;
        let adapter = ModuleAdapter::on_failure("err-health", state.clone(), |_s| {
            Box::pin(async move { Err(crate::Error::Config("kaboom".into())) })
        });
        let _ = run_boxed(adapter).await;

        let health = state.get_module_health().await;
        let entry = health
            .iter()
            .find(|h| h.name == "err-health")
            .expect("err-health should have a health entry");
        assert!(!entry.healthy);
        let message = entry.message.as_deref().unwrap_or("");
        assert!(message.contains("kaboom"));
    }

    #[test]
    fn test_module_adapter_debug() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(async {
            let adapter = ModuleAdapter::on_failure("dbg", test_state().await, |_s| {
                Box::pin(async move { Ok(()) })
            });
            let rendered = format!("{adapter:?}");
            for needle in ["ModuleAdapter", "dbg"] {
                assert!(rendered.contains(needle));
            }
        });
    }

    #[test]
    fn test_restart_policies() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(async {
            let state = test_state().await;
            let cases = [
                (
                    ModuleAdapter::on_failure("a", state.clone(), |_| Box::pin(async { Ok(()) })),
                    RestartPolicy::OnFailure,
                ),
                (
                    ModuleAdapter::one_shot("b", state.clone(), |_| Box::pin(async { Ok(()) })),
                    RestartPolicy::Never,
                ),
                (
                    ModuleAdapter::always_restart("c", state.clone(), |_| {
                        Box::pin(async { Ok(()) })
                    }),
                    RestartPolicy::Always,
                ),
            ];
            for (adapter, expected) in cases {
                assert_eq!(adapter.restart_policy(), expected);
            }
        });
    }

    #[tokio::test]
    async fn test_api_module_adapter() {
        let state = test_state().await;
        let calls = Arc::new(AtomicU64::new(0));
        let calls_in_module = Arc::clone(&calls);
        let adapter = ApiModuleAdapter::new(state.clone(), move |_s| {
            let calls = Arc::clone(&calls_in_module);
            Box::pin(async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
        });
        assert_eq!(adapter.name(), "api");
        assert_eq!(adapter.restart_policy(), RestartPolicy::Always);

        run_boxed(adapter).await.unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}