Skip to main content

jflow_core/supervisor/adapters/
mod.rs

//! Service adapters for integrating existing Janus modules with the
//! [`JanusSupervisor`](super::JanusSupervisor).
//!
//! The existing Janus services (Forward, Backward, CNS, Data) expose a
//! `start_module(state: Arc<JanusState>) -> Result<()>` entry point that
//! internally polls `state.is_shutdown_requested()` for lifecycle control.
//!
//! The supervisor, however, manages services through the [`JanusService`]
//! trait which passes a [`CancellationToken`] for shutdown signalling.
//!
//! This module provides [`ModuleAdapter`] — a bridge that:
//!
//! 1. Wraps any `start_module`-style async function into a [`JanusService`]
//! 2. Bridges the supervisor's `CancellationToken` → `JanusState::request_shutdown()`
//! 3. Propagates errors back to the supervisor for restart decisions
//!
//! # Example
//!
//! ```rust,ignore
//! use janus_core::supervisor::adapters::ModuleAdapter;
//! use janus_core::supervisor::RestartPolicy;
//!
//! let data_adapter = ModuleAdapter::new(
//!     "data",
//!     state.clone(),
//!     |s| Box::pin(janus_data::start_module(s)),
//!     RestartPolicy::OnFailure,
//! );
//!
//! supervisor.spawn_service(Box::new(data_adapter));
//! ```
32
33use std::future::Future;
34use std::pin::Pin;
35use std::sync::Arc;
36
37use async_trait::async_trait;
38use tokio_util::sync::CancellationToken;
39
40use super::service::{JanusService, RestartPolicy};
41use crate::state::JanusState;
42
43// ---------------------------------------------------------------------------
44// ModuleStartFn — type alias for the start_module function signature
45// ---------------------------------------------------------------------------
46
47/// A boxed async function that takes `Arc<JanusState>` and returns a `Result`.
48///
49/// This matches the signature of every existing Janus service's
50/// `start_module()` function.
51pub type ModuleStartFn = Box<
52    dyn Fn(Arc<JanusState>) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>>
53        + Send
54        + Sync,
55>;
56
57// ---------------------------------------------------------------------------
58// ModuleAdapter
59// ---------------------------------------------------------------------------
60
61/// Adapts an existing `start_module(Arc<JanusState>) -> Result<()>` function
62/// into a [`JanusService`] that the supervisor can manage.
63///
64/// ## Shutdown Bridging
65///
66/// When the supervisor cancels this service's [`CancellationToken`], the
67/// adapter calls `state.request_shutdown()` to signal the inner module
68/// through its existing shutdown pathway. This avoids having to rewrite
69/// every service to accept a `CancellationToken` directly — they continue
70/// polling `state.is_shutdown_requested()` as before.
71///
72/// ### Shared-State Shutdown Invariant
73///
74/// Because `JanusState` is shared across all modules, `request_shutdown()`
75/// sets a **global** `AtomicBool`.  This is safe under the current
76/// supervisor model because:
77///
78/// 1. The supervisor propagates cancellation to **all** services at once
79///    (child tokens of the same root), so if one bridge fires they all do.
80/// 2. When a service fails on its own (returns `Err`), the bridge task is
81///    aborted *before* it can call `request_shutdown()`, keeping the flag
82///    `false` for the restart.
83/// 3. The supervisor checks `cancel.is_cancelled()` before each restart
84///    attempt, so a stale `true` flag never leads to a restart loop.
85///
86/// If per-module independent restart semantics are ever needed (e.g.,
87/// restarting one module while others keep running), this must be
88/// refactored to use a per-adapter shutdown signal instead of the global
89/// `JanusState` flag.
90///
91/// ## Error Propagation
92///
93/// Errors returned by the inner `start_module` are converted to
94/// `anyhow::Error` and bubbled up to the supervisor, which then applies
95/// the configured [`RestartPolicy`].
96pub struct ModuleAdapter {
97    /// Human-readable name for logging and metrics.
98    name: String,
99
100    /// Shared application state passed to the module on each (re)start.
101    state: Arc<JanusState>,
102
103    /// The `start_module`-style function to call.
104    start_fn: ModuleStartFn,
105
106    /// Restart policy for this module.
107    policy: RestartPolicy,
108}
109
110impl ModuleAdapter {
111    /// Create a new adapter.
112    ///
113    /// # Arguments
114    ///
115    /// * `name` — Service name for the supervisor (e.g., `"forward"`,
116    ///   `"data"`, `"cns"`).
117    /// * `state` — Shared `JanusState` that the module uses for
118    ///   configuration, health reporting, and shutdown signalling.
119    /// * `start_fn` — A closure or function pointer matching the
120    ///   `start_module` signature. Must be `Send + Sync + 'static`.
121    /// * `policy` — How the supervisor should handle failures.
122    ///
123    /// # Example
124    ///
125    /// ```rust,ignore
126    /// let adapter = ModuleAdapter::new(
127    ///     "forward",
128    ///     state.clone(),
129    ///     |s| Box::pin(janus_forward::start_module(s)),
130    ///     RestartPolicy::OnFailure,
131    /// );
132    /// ```
133    pub fn new<F>(name: &str, state: Arc<JanusState>, start_fn: F, policy: RestartPolicy) -> Self
134    where
135        F: Fn(Arc<JanusState>) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>>
136            + Send
137            + Sync
138            + 'static,
139    {
140        Self {
141            name: name.to_string(),
142            state,
143            start_fn: Box::new(start_fn),
144            policy,
145        }
146    }
147
148    /// Create an adapter with the default [`RestartPolicy::OnFailure`].
149    pub fn on_failure<F>(name: &str, state: Arc<JanusState>, start_fn: F) -> Self
150    where
151        F: Fn(Arc<JanusState>) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>>
152            + Send
153            + Sync
154            + 'static,
155    {
156        Self::new(name, state, start_fn, RestartPolicy::OnFailure)
157    }
158
159    /// Create an adapter that never restarts (one-shot module).
160    pub fn one_shot<F>(name: &str, state: Arc<JanusState>, start_fn: F) -> Self
161    where
162        F: Fn(Arc<JanusState>) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>>
163            + Send
164            + Sync
165            + 'static,
166    {
167        Self::new(name, state, start_fn, RestartPolicy::Never)
168    }
169
170    /// Create an adapter that always restarts (even on clean exit).
171    pub fn always_restart<F>(name: &str, state: Arc<JanusState>, start_fn: F) -> Self
172    where
173        F: Fn(Arc<JanusState>) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>>
174            + Send
175            + Sync
176            + 'static,
177    {
178        Self::new(name, state, start_fn, RestartPolicy::Always)
179    }
180}
181
182#[async_trait]
183impl JanusService for ModuleAdapter {
184    fn name(&self) -> &str {
185        &self.name
186    }
187
188    fn restart_policy(&self) -> RestartPolicy {
189        self.policy
190    }
191
192    #[tracing::instrument(skip(self, cancel), fields(module = %self.name, policy = %self.policy))]
193    async fn run(&self, cancel: CancellationToken) -> anyhow::Result<()> {
194        // Defensive check: if `shutdown_requested` is already `true` on
195        // the shared JanusState, the module will likely exit immediately.
196        // This should never happen during a normal restart cycle (the
197        // supervisor checks `cancel.is_cancelled()` first, and the bridge
198        // task is aborted on natural module exit), but we log a warning
199        // so the condition is diagnosable if the invariant is ever broken.
200        if self.state.is_shutdown_requested() {
201            tracing::warn!(
202                module = %self.name,
203                "JanusState.shutdown_requested is already true at module start — \
204                 the module may exit immediately. This indicates a stale shutdown \
205                 flag from a previous lifecycle or a concurrent shutdown in progress."
206            );
207        }
208
209        // Spawn a bridge task that watches the CancellationToken and
210        // translates it into a JanusState shutdown request.
211        let bridge_state = self.state.clone();
212        let bridge_cancel = cancel.clone();
213        let bridge_handle = tokio::spawn(async move {
214            bridge_cancel.cancelled().await;
215            tracing::info!(
216                "ModuleAdapter shutdown bridge: cancellation received, requesting state shutdown"
217            );
218            bridge_state.request_shutdown();
219        });
220
221        // Register health before starting
222        self.state
223            .register_module_health(&self.name, true, Some("starting".to_string()))
224            .await;
225
226        // Run the module
227        let state_clone = self.state.clone();
228        let module_result = (self.start_fn)(state_clone).await;
229
230        // Abort the bridge task if the module exited on its own
231        // (e.g., due to an internal error, not a cancellation).
232        bridge_handle.abort();
233
234        // Update health based on result
235        match &module_result {
236            Ok(()) => {
237                self.state
238                    .register_module_health(&self.name, true, Some("stopped".to_string()))
239                    .await;
240            }
241            Err(e) => {
242                self.state
243                    .register_module_health(&self.name, false, Some(format!("error: {e}")))
244                    .await;
245            }
246        }
247
248        module_result.map_err(|e| anyhow::anyhow!("{}: {}", self.name, e))
249    }
250}
251
252// We can't derive Debug because ModuleStartFn contains a trait object,
253// but we can implement it manually for diagnostics.
254impl std::fmt::Debug for ModuleAdapter {
255    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
256        f.debug_struct("ModuleAdapter")
257            .field("name", &self.name)
258            .field("policy", &self.policy)
259            .finish_non_exhaustive()
260    }
261}
262
263// ---------------------------------------------------------------------------
264// ApiModuleAdapter — specialised for the always-on API module
265// ---------------------------------------------------------------------------
266
267/// A convenience wrapper for the API module which is always started
268/// immediately (it doesn't wait for `start_services()`) and should
269/// always be restarted on failure.
270///
271/// This is functionally identical to `ModuleAdapter::always_restart()`
272/// but has a dedicated type for clarity in the supervisor setup code.
273pub struct ApiModuleAdapter {
274    inner: ModuleAdapter,
275}
276
277impl ApiModuleAdapter {
278    /// Create a new API module adapter.
279    pub fn new<F>(state: Arc<JanusState>, start_fn: F) -> Self
280    where
281        F: Fn(Arc<JanusState>) -> Pin<Box<dyn Future<Output = crate::Result<()>> + Send>>
282            + Send
283            + Sync
284            + 'static,
285    {
286        Self {
287            inner: ModuleAdapter::new("api", state, start_fn, RestartPolicy::Always),
288        }
289    }
290}
291
292#[async_trait]
293impl JanusService for ApiModuleAdapter {
294    fn name(&self) -> &str {
295        self.inner.name()
296    }
297
298    fn restart_policy(&self) -> RestartPolicy {
299        RestartPolicy::Always
300    }
301
302    async fn run(&self, cancel: CancellationToken) -> anyhow::Result<()> {
303        self.inner.run(cancel).await
304    }
305}
306
307impl std::fmt::Debug for ApiModuleAdapter {
308    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309        f.debug_struct("ApiModuleAdapter")
310            .field("inner", &self.inner)
311            .finish()
312    }
313}
314
315// ===========================================================================
316// Tests
317// ===========================================================================
318
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::Duration;

    /// Upper bound for a single adapter run in tests that depend on the
    /// cancellation bridge; a broken bridge then fails the test instead of
    /// hanging the whole suite.
    const RUN_TIMEOUT: Duration = Duration::from_secs(5);

    /// Helper: build a `JanusState` from the default configuration.
    async fn test_state() -> Arc<JanusState> {
        let config = crate::Config::default();
        Arc::new(
            JanusState::new(config)
                .await
                .expect("default config should produce a JanusState"),
        )
    }

    /// A module that returns `Ok(())` runs exactly once and yields `Ok`.
    #[tokio::test]
    async fn test_module_adapter_clean_exit() {
        let state = test_state().await;

        let ran = Arc::new(AtomicU64::new(0));
        let ran_clone = ran.clone();

        let adapter = ModuleAdapter::on_failure("test-clean", state.clone(), move |_s| {
            let ran = ran_clone.clone();
            Box::pin(async move {
                ran.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
        });

        let cancel = CancellationToken::new();
        let svc: Box<dyn JanusService> = Box::new(adapter);
        let result = svc.run(cancel).await;

        assert!(result.is_ok());
        assert_eq!(ran.load(Ordering::SeqCst), 1);
    }

    /// A module error reaches the caller and names both service and cause.
    #[tokio::test]
    async fn test_module_adapter_error_propagation() {
        let state = test_state().await;

        let adapter = ModuleAdapter::on_failure("test-fail", state.clone(), |_s| {
            Box::pin(async move { Err(crate::Error::Config("boom".into())) })
        });

        let cancel = CancellationToken::new();
        let svc: Box<dyn JanusService> = Box::new(adapter);
        let err_msg = svc
            .run(cancel)
            .await
            .expect_err("a failing module must surface an error")
            .to_string();

        assert!(
            err_msg.contains("test-fail"),
            "error should contain service name: {err_msg}"
        );
        assert!(
            err_msg.contains("boom"),
            "error should contain cause: {err_msg}"
        );
    }

    /// Cancelling the token sets the shared shutdown flag, which lets a
    /// flag-polling module exit cleanly.
    #[tokio::test]
    async fn test_module_adapter_cancellation_bridge() {
        let state = test_state().await;

        let adapter = ModuleAdapter::on_failure("test-cancel", state.clone(), |s| {
            Box::pin(async move {
                // Simulate a module that polls the shutdown flag.
                loop {
                    if s.is_shutdown_requested() {
                        return Ok(());
                    }
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }
            })
        });

        let cancel = CancellationToken::new();
        let canceller = cancel.clone();

        // Cancel after a short delay.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            canceller.cancel();
        });

        let svc: Box<dyn JanusService> = Box::new(adapter);
        let result = tokio::time::timeout(RUN_TIMEOUT, svc.run(cancel))
            .await
            .expect("adapter should stop promptly once the token is cancelled");

        assert!(result.is_ok());
        assert!(state.is_shutdown_requested());
    }

    /// After a clean run the module has a healthy health entry.
    #[tokio::test]
    async fn test_module_adapter_health_registration() {
        let state = test_state().await;

        let adapter = ModuleAdapter::on_failure("health-test", state.clone(), |_s| {
            Box::pin(async move { Ok(()) })
        });

        let cancel = CancellationToken::new();
        let svc: Box<dyn JanusService> = Box::new(adapter);
        svc.run(cancel).await.unwrap();

        let health = state.get_module_health().await;
        let entry = health
            .iter()
            .find(|h| h.name == "health-test")
            .expect("module should have registered health");
        assert!(entry.healthy);
    }

    /// After a failed run the health entry is unhealthy and carries the cause.
    #[tokio::test]
    async fn test_module_adapter_health_on_error() {
        let state = test_state().await;

        let adapter = ModuleAdapter::on_failure("err-health", state.clone(), |_s| {
            Box::pin(async move { Err(crate::Error::Config("kaboom".into())) })
        });

        let cancel = CancellationToken::new();
        let svc: Box<dyn JanusService> = Box::new(adapter);
        let _ = svc.run(cancel).await;

        let health = state.get_module_health().await;
        let entry = health
            .iter()
            .find(|h| h.name == "err-health")
            .expect("failed module should still have a health entry");
        assert!(!entry.healthy);

        let message = entry.message.as_deref().unwrap_or("");
        assert!(
            message.contains("kaboom"),
            "health message should carry the cause: {message}"
        );
    }

    /// `Debug` output names the type and the service, and does not panic.
    #[tokio::test]
    async fn test_module_adapter_debug() {
        let state = test_state().await;
        let adapter = ModuleAdapter::on_failure("dbg", state, |_s| Box::pin(async move { Ok(()) }));

        let dbg = format!("{:?}", adapter);
        assert!(dbg.contains("ModuleAdapter"));
        assert!(dbg.contains("dbg"));
    }

    /// Each convenience constructor reports the policy it is named after.
    #[tokio::test]
    async fn test_restart_policies() {
        let state = test_state().await;

        let a1 = ModuleAdapter::on_failure("a", state.clone(), |_| Box::pin(async { Ok(()) }));
        assert_eq!(a1.restart_policy(), RestartPolicy::OnFailure);

        let a2 = ModuleAdapter::one_shot("b", state.clone(), |_| Box::pin(async { Ok(()) }));
        assert_eq!(a2.restart_policy(), RestartPolicy::Never);

        let a3 = ModuleAdapter::always_restart("c", state.clone(), |_| Box::pin(async { Ok(()) }));
        assert_eq!(a3.restart_policy(), RestartPolicy::Always);
    }

    /// The API adapter is named `"api"`, always restarts, and runs its module.
    #[tokio::test]
    async fn test_api_module_adapter() {
        let state = test_state().await;
        let ran = Arc::new(AtomicU64::new(0));
        let ran_clone = ran.clone();

        let adapter = ApiModuleAdapter::new(state.clone(), move |_s| {
            let ran = ran_clone.clone();
            Box::pin(async move {
                ran.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
        });

        assert_eq!(adapter.name(), "api");
        assert_eq!(adapter.restart_policy(), RestartPolicy::Always);

        let cancel = CancellationToken::new();
        let svc: Box<dyn JanusService> = Box::new(adapter);
        svc.run(cancel).await.unwrap();

        assert_eq!(ran.load(Ordering::SeqCst), 1);
    }
}