jflow_core/supervisor/adapters/
mod.rs1use std::future::Future;
34use std::pin::Pin;
35use std::sync::Arc;
36
37use async_trait::async_trait;
38use tokio_util::sync::CancellationToken;
39
40use super::service::{JanusService, RestartPolicy};
41use crate::state::JanusState;
42
43pub type ModuleStartFn = Box<
52 dyn Fn(Arc<JanusState>) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>>
53 + Send
54 + Sync,
55>;
56
57pub struct ModuleAdapter {
97 name: String,
99
100 state: Arc<JanusState>,
102
103 start_fn: ModuleStartFn,
105
106 policy: RestartPolicy,
108}
109
110impl ModuleAdapter {
111 pub fn new<F>(name: &str, state: Arc<JanusState>, start_fn: F, policy: RestartPolicy) -> Self
134 where
135 F: Fn(Arc<JanusState>) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>>
136 + Send
137 + Sync
138 + 'static,
139 {
140 Self {
141 name: name.to_string(),
142 state,
143 start_fn: Box::new(start_fn),
144 policy,
145 }
146 }
147
148 pub fn on_failure<F>(name: &str, state: Arc<JanusState>, start_fn: F) -> Self
150 where
151 F: Fn(Arc<JanusState>) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>>
152 + Send
153 + Sync
154 + 'static,
155 {
156 Self::new(name, state, start_fn, RestartPolicy::OnFailure)
157 }
158
159 pub fn one_shot<F>(name: &str, state: Arc<JanusState>, start_fn: F) -> Self
161 where
162 F: Fn(Arc<JanusState>) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>>
163 + Send
164 + Sync
165 + 'static,
166 {
167 Self::new(name, state, start_fn, RestartPolicy::Never)
168 }
169
170 pub fn always_restart<F>(name: &str, state: Arc<JanusState>, start_fn: F) -> Self
172 where
173 F: Fn(Arc<JanusState>) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>>
174 + Send
175 + Sync
176 + 'static,
177 {
178 Self::new(name, state, start_fn, RestartPolicy::Always)
179 }
180}
181
182#[async_trait]
183impl JanusService for ModuleAdapter {
184 fn name(&self) -> &str {
185 &self.name
186 }
187
188 fn restart_policy(&self) -> RestartPolicy {
189 self.policy
190 }
191
192 #[tracing::instrument(skip(self, cancel), fields(module = %self.name, policy = %self.policy))]
193 async fn run(&self, cancel: CancellationToken) -> anyhow::Result<()> {
194 if self.state.is_shutdown_requested() {
201 tracing::warn!(
202 module = %self.name,
203 "JanusState.shutdown_requested is already true at module start — \
204 the module may exit immediately. This indicates a stale shutdown \
205 flag from a previous lifecycle or a concurrent shutdown in progress."
206 );
207 }
208
209 let bridge_state = self.state.clone();
212 let bridge_cancel = cancel.clone();
213 let bridge_handle = tokio::spawn(async move {
214 bridge_cancel.cancelled().await;
215 tracing::info!(
216 "ModuleAdapter shutdown bridge: cancellation received, requesting state shutdown"
217 );
218 bridge_state.request_shutdown();
219 });
220
221 self.state
223 .register_module_health(&self.name, true, Some("starting".to_string()))
224 .await;
225
226 let state_clone = self.state.clone();
228 let module_result = (self.start_fn)(state_clone).await;
229
230 bridge_handle.abort();
233
234 match &module_result {
236 Ok(()) => {
237 self.state
238 .register_module_health(&self.name, true, Some("stopped".to_string()))
239 .await;
240 }
241 Err(e) => {
242 self.state
243 .register_module_health(&self.name, false, Some(format!("error: {e}")))
244 .await;
245 }
246 }
247
248 module_result.map_err(|e| anyhow::anyhow!("{}: {}", self.name, e))
249 }
250}
251
252impl std::fmt::Debug for ModuleAdapter {
255 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
256 f.debug_struct("ModuleAdapter")
257 .field("name", &self.name)
258 .field("policy", &self.policy)
259 .finish_non_exhaustive()
260 }
261}
262
263pub struct ApiModuleAdapter {
274 inner: ModuleAdapter,
275}
276
277impl ApiModuleAdapter {
278 pub fn new<F>(state: Arc<JanusState>, start_fn: F) -> Self
280 where
281 F: Fn(Arc<JanusState>) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>>
282 + Send
283 + Sync
284 + 'static,
285 {
286 Self {
287 inner: ModuleAdapter::new("api", state, start_fn, RestartPolicy::Always),
288 }
289 }
290}
291
292#[async_trait]
293impl JanusService for ApiModuleAdapter {
294 fn name(&self) -> &str {
295 self.inner.name()
296 }
297
298 fn restart_policy(&self) -> RestartPolicy {
299 RestartPolicy::Always
300 }
301
302 async fn run(&self, cancel: CancellationToken) -> anyhow::Result<()> {
303 self.inner.run(cancel).await
304 }
305}
306
307impl std::fmt::Debug for ApiModuleAdapter {
308 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309 f.debug_struct("ApiModuleAdapter")
310 .field("inner", &self.inner)
311 .finish()
312 }
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::Duration;

    /// Upper bound on how long a test waits for a module to react to
    /// cancellation, so a broken bridge fails the test instead of hanging it.
    const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);

    /// Builds a fresh `JanusState` from the default config.
    async fn test_state() -> Arc<JanusState> {
        let config = crate::Config::default();
        Arc::new(
            JanusState::new(config)
                .await
                .expect("JanusState::new should succeed with the default config"),
        )
    }

    /// A module that polls the shared shutdown flag every 10 ms and returns
    /// `Ok(())` once it is set.
    fn shutdown_polling_module(name: &str, state: Arc<JanusState>) -> ModuleAdapter {
        ModuleAdapter::on_failure(name, state, |s| {
            Box::pin(async move {
                loop {
                    if s.is_shutdown_requested() {
                        return Ok(());
                    }
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }
            })
        })
    }

    #[tokio::test]
    async fn test_module_adapter_clean_exit() {
        let state = test_state().await;

        let ran = Arc::new(AtomicU64::new(0));
        let ran_clone = ran.clone();

        let adapter = ModuleAdapter::on_failure("test-clean", state.clone(), move |_s| {
            let ran = ran_clone.clone();
            Box::pin(async move {
                ran.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
        });

        let cancel = CancellationToken::new();
        let svc: Box<dyn JanusService> = Box::new(adapter);
        let result = svc.run(cancel).await;

        assert!(result.is_ok());
        assert_eq!(ran.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_module_adapter_error_propagation() {
        let state = test_state().await;

        let adapter = ModuleAdapter::on_failure("test-fail", state.clone(), |_s| {
            Box::pin(async move { Err(crate::Error::Config("boom".into())) })
        });

        let cancel = CancellationToken::new();
        let svc: Box<dyn JanusService> = Box::new(adapter);
        let err_msg = svc
            .run(cancel)
            .await
            .expect_err("a failing module must surface its error")
            .to_string();

        assert!(
            err_msg.contains("test-fail"),
            "error should contain service name: {err_msg}"
        );
        assert!(
            err_msg.contains("boom"),
            "error should contain cause: {err_msg}"
        );
    }

    #[tokio::test]
    async fn test_module_adapter_cancellation_bridge() {
        let state = test_state().await;
        let adapter = shutdown_polling_module("test-cancel", state.clone());

        let cancel = CancellationToken::new();
        let cancel_clone = cancel.clone();

        // Cancel shortly after the module has started polling.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            cancel_clone.cancel();
        });

        let svc: Box<dyn JanusService> = Box::new(adapter);
        let result = tokio::time::timeout(SHUTDOWN_TIMEOUT, svc.run(cancel))
            .await
            .expect("module never observed the shutdown requested by the bridge");

        assert!(result.is_ok());
        assert!(state.is_shutdown_requested());
    }

    #[tokio::test]
    async fn test_module_adapter_pre_cancelled_token() {
        let state = test_state().await;
        let adapter = shutdown_polling_module("test-precancel", state.clone());

        // A token cancelled before `run` must still be bridged to the state.
        let cancel = CancellationToken::new();
        cancel.cancel();

        let svc: Box<dyn JanusService> = Box::new(adapter);
        let result = tokio::time::timeout(SHUTDOWN_TIMEOUT, svc.run(cancel))
            .await
            .expect("module never observed the shutdown requested by the bridge");

        assert!(result.is_ok());
        assert!(state.is_shutdown_requested());
    }

    #[tokio::test]
    async fn test_module_adapter_health_starting_while_running() {
        let state = test_state().await;

        let adapter = ModuleAdapter::on_failure("starting-test", state.clone(), |s| {
            Box::pin(async move {
                let health = s.get_module_health().await;
                let entry = health
                    .iter()
                    .find(|h| h.name == "starting-test")
                    .expect("module should be registered before its body runs");
                assert!(entry.healthy);
                assert_eq!(entry.message.as_deref(), Some("starting"));
                Ok(())
            })
        });

        let cancel = CancellationToken::new();
        let svc: Box<dyn JanusService> = Box::new(adapter);
        svc.run(cancel).await.expect("module should exit cleanly");
    }

    #[tokio::test]
    async fn test_module_adapter_health_registration() {
        let state = test_state().await;

        let adapter = ModuleAdapter::on_failure("health-test", state.clone(), |_s| {
            Box::pin(async move { Ok(()) })
        });

        let cancel = CancellationToken::new();
        let svc: Box<dyn JanusService> = Box::new(adapter);
        svc.run(cancel).await.expect("module should exit cleanly");

        let health = state.get_module_health().await;
        let entry = health
            .iter()
            .find(|h| h.name == "health-test")
            .expect("module should have registered health");
        assert!(entry.healthy);
    }

    #[tokio::test]
    async fn test_module_adapter_health_on_error() {
        let state = test_state().await;

        let adapter = ModuleAdapter::on_failure("err-health", state.clone(), |_s| {
            Box::pin(async move { Err(crate::Error::Config("kaboom".into())) })
        });

        let cancel = CancellationToken::new();
        let svc: Box<dyn JanusService> = Box::new(adapter);
        let _ = svc.run(cancel).await;

        let health = state.get_module_health().await;
        let entry = health
            .iter()
            .find(|h| h.name == "err-health")
            .expect("failing module should still have registered health");
        assert!(!entry.healthy);
        assert!(
            entry.message.as_deref().unwrap_or("").contains("kaboom"),
            "health message should carry the error cause"
        );
    }

    #[tokio::test]
    async fn test_module_adapter_debug() {
        let state = test_state().await;
        let adapter =
            ModuleAdapter::on_failure("dbg", state, |_s| Box::pin(async move { Ok(()) }));

        let rendered = format!("{adapter:?}");
        assert!(rendered.contains("ModuleAdapter"));
        assert!(rendered.contains("dbg"));
    }

    #[tokio::test]
    async fn test_restart_policies() {
        let state = test_state().await;

        let a1 = ModuleAdapter::on_failure("a", state.clone(), |_| Box::pin(async { Ok(()) }));
        assert_eq!(a1.restart_policy(), RestartPolicy::OnFailure);

        let a2 = ModuleAdapter::one_shot("b", state.clone(), |_| Box::pin(async { Ok(()) }));
        assert_eq!(a2.restart_policy(), RestartPolicy::Never);

        let a3 =
            ModuleAdapter::always_restart("c", state.clone(), |_| Box::pin(async { Ok(()) }));
        assert_eq!(a3.restart_policy(), RestartPolicy::Always);
    }

    #[tokio::test]
    async fn test_api_module_adapter() {
        let state = test_state().await;
        let ran = Arc::new(AtomicU64::new(0));
        let ran_clone = ran.clone();

        let adapter = ApiModuleAdapter::new(state.clone(), move |_s| {
            let ran = ran_clone.clone();
            Box::pin(async move {
                ran.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
        });

        assert_eq!(adapter.name(), "api");
        assert_eq!(adapter.restart_policy(), RestartPolicy::Always);

        let cancel = CancellationToken::new();
        let svc: Box<dyn JanusService> = Box::new(adapter);
        svc.run(cancel).await.expect("api module should exit cleanly");

        assert_eq!(ran.load(Ordering::SeqCst), 1);
    }
}