Skip to main content

forge_core/daemon/
context.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use tokio::sync::{Mutex, watch};
5use tracing::Span;
6use uuid::Uuid;
7
8use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
9use crate::function::{JobDispatch, KvHandle, WorkflowDispatch};
10use crate::http::CircuitBreakerClient;
11
/// Context available to daemon handlers.
///
/// Created with [`DaemonContext::new`] and optionally extended through the
/// `with_*` builder methods before being handed to a handler. It bundles the
/// database pool, an HTTP client, optional job/workflow dispatchers, an
/// optional KV store handle, the shutdown signal, environment access and the
/// tracing span captured at construction.
#[non_exhaustive]
pub struct DaemonContext {
    /// Name of the daemon this context belongs to; used to match the row
    /// updated by `heartbeat()`.
    pub daemon_name: String,
    /// Identifier of this daemon instance; `trace_id()` returns it as a string.
    pub instance_id: Uuid,
    /// Postgres pool backing `db()`, `db_conn()`, `conn()` and `heartbeat()`.
    db_pool: sqlx::PgPool,
    /// HTTP client used by `http()`; `raw_http()` exposes its inner
    /// `reqwest::Client`.
    http_client: CircuitBreakerClient,
    /// `None` means unlimited.
    http_timeout: Option<Duration>,
    /// Shutdown channel receiver; a value of `true` means shutdown was
    /// requested. Guarded by a `Mutex` so `&self` methods can reach it — see
    /// `is_shutdown_requested` and `shutdown_signal` for how the lock is used.
    shutdown_rx: Mutex<watch::Receiver<bool>>,
    /// `None` until `with_job_dispatch` is called; job methods then return an
    /// internal error.
    job_dispatch: Option<Arc<dyn JobDispatch>>,
    /// `None` until `with_workflow_dispatch` is called; workflow methods then
    /// return an internal error.
    workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
    /// Source for `EnvAccess` lookups; a `RealEnvProvider` unless replaced via
    /// `with_env_provider`.
    env_provider: Arc<dyn EnvProvider>,
    /// Tracing span that was current when the context was constructed.
    span: Span,
    /// `None` until `with_kv` is called; `kv()` then returns an internal error.
    kv: Option<Arc<dyn KvHandle>>,
}
29
30impl DaemonContext {
31    /// Create a new daemon context.
32    pub fn new(
33        daemon_name: String,
34        instance_id: Uuid,
35        db_pool: sqlx::PgPool,
36        http_client: CircuitBreakerClient,
37        shutdown_rx: watch::Receiver<bool>,
38    ) -> Self {
39        Self {
40            daemon_name,
41            instance_id,
42            db_pool,
43            http_client,
44            http_timeout: None,
45            shutdown_rx: Mutex::new(shutdown_rx),
46            job_dispatch: None,
47            workflow_dispatch: None,
48            env_provider: Arc::new(RealEnvProvider::new()),
49            span: Span::current(),
50            kv: None,
51        }
52    }
53
54    /// Attach a KV store handle. Called by the runtime before handing the
55    /// context to the handler.
56    pub fn with_kv(mut self, kv: Arc<dyn KvHandle>) -> Self {
57        self.kv = Some(kv);
58        self
59    }
60
61    /// Access the KV store.
62    pub fn kv(&self) -> crate::error::Result<&dyn KvHandle> {
63        self.kv
64            .as_deref()
65            .ok_or_else(|| crate::error::ForgeError::internal("KV store not available"))
66    }
67
68    /// Set job dispatcher.
69    pub fn with_job_dispatch(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
70        self.job_dispatch = Some(dispatcher);
71        self
72    }
73
74    /// Set workflow dispatcher.
75    pub fn with_workflow_dispatch(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
76        self.workflow_dispatch = Some(dispatcher);
77        self
78    }
79
80    /// Set environment provider.
81    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
82        self.env_provider = provider;
83        self
84    }
85
86    pub fn db(&self) -> crate::function::ForgeDb {
87        crate::function::ForgeDb::from_pool(&self.db_pool)
88    }
89
90    /// Get a `DbConn` for use in shared helper functions.
91    pub fn db_conn(&self) -> crate::function::DbConn<'_> {
92        crate::function::DbConn::Pool(self.db_pool.clone())
93    }
94
95    /// Acquire a connection compatible with sqlx compile-time checked macros.
96    pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
97        Ok(crate::function::ForgeConn::Pool(
98            self.db_pool.acquire().await?,
99        ))
100    }
101
102    pub fn http(&self) -> crate::http::HttpClient {
103        self.http_client.with_timeout(self.http_timeout)
104    }
105
106    pub fn raw_http(&self) -> &reqwest::Client {
107        self.http_client.inner()
108    }
109
110    pub fn set_http_timeout(&mut self, timeout: Option<Duration>) {
111        self.http_timeout = timeout;
112    }
113
114    /// Dispatch a background job.
115    pub async fn dispatch_job<T: serde::Serialize>(
116        &self,
117        job_type: &str,
118        args: T,
119    ) -> crate::Result<Uuid> {
120        let dispatcher = self
121            .job_dispatch
122            .as_ref()
123            .ok_or_else(|| crate::error::ForgeError::internal("Job dispatch not available"))?;
124
125        let args_json = serde_json::to_value(args)?;
126        dispatcher
127            .dispatch_by_name(job_type, args_json, None, None)
128            .await
129    }
130
131    /// Type-safe dispatch: resolves the job name from the type's `ForgeJob`
132    /// impl and serializes the args at the call site.
133    pub async fn dispatch<J: crate::ForgeJob>(&self, args: J::Args) -> crate::Result<Uuid> {
134        self.dispatch_job(J::info().name, args).await
135    }
136
137    /// Request cancellation for a job.
138    pub async fn cancel_job(&self, job_id: Uuid, reason: Option<String>) -> crate::Result<bool> {
139        let dispatcher = self
140            .job_dispatch
141            .as_ref()
142            .ok_or_else(|| crate::error::ForgeError::internal("Job dispatch not available"))?;
143        dispatcher.cancel(job_id, reason).await
144    }
145
146    /// Start a workflow.
147    pub async fn start_workflow<T: serde::Serialize>(
148        &self,
149        workflow_name: &str,
150        input: T,
151    ) -> crate::Result<Uuid> {
152        let dispatcher = self
153            .workflow_dispatch
154            .as_ref()
155            .ok_or_else(|| crate::error::ForgeError::internal("Workflow dispatch not available"))?;
156
157        let input_json = serde_json::to_value(input)?;
158        dispatcher
159            .start_by_name(workflow_name, input_json, None, None)
160            .await
161    }
162
163    /// Type-safe workflow start.
164    pub async fn start<W: crate::ForgeWorkflow>(&self, input: W::Input) -> crate::Result<Uuid> {
165        self.start_workflow(W::info().name, input).await
166    }
167
168    /// Check if shutdown has been requested.
169    pub fn is_shutdown_requested(&self) -> bool {
170        // Use try_lock to avoid blocking; if can't lock, assume not shutdown
171        self.shutdown_rx
172            .try_lock()
173            .map(|rx| *rx.borrow())
174            .unwrap_or(false)
175    }
176
177    /// Wait for shutdown signal.
178    ///
179    /// Use this in a `tokio::select!` to handle graceful shutdown:
180    ///
181    /// ```ignore
182    /// tokio::select! {
183    ///     _ = tokio::time::sleep(Duration::from_secs(60)) => {}
184    ///     _ = ctx.shutdown_signal() => break,
185    /// }
186    /// ```
187    pub async fn shutdown_signal(&self) {
188        let mut rx = self.shutdown_rx.lock().await;
189        // Wait until the value becomes true
190        while !*rx.borrow_and_update() {
191            if rx.changed().await.is_err() {
192                // Channel closed, treat as shutdown
193                break;
194            }
195        }
196    }
197
198    /// Sleep for `interval`, waking early if shutdown is requested.
199    ///
200    /// Returns `true` if the daemon should continue, `false` if shutdown was
201    /// requested before or during the sleep. Intended for the main daemon loop:
202    ///
203    /// ```ignore
204    /// while ctx.tick(Duration::from_secs(60)).await {
205    ///     // do periodic work
206    /// }
207    /// ```
208    pub async fn tick(&self, interval: Duration) -> bool {
209        tokio::select! {
210            _ = tokio::time::sleep(interval) => true,
211            _ = self.shutdown_signal() => false,
212        }
213    }
214
215    /// Send heartbeat to indicate daemon is alive.
216    pub async fn heartbeat(&self) -> crate::Result<()> {
217        tracing::trace!(daemon.name = %self.daemon_name, "Sending heartbeat");
218
219        sqlx::query!(
220            r#"
221            UPDATE forge_daemons
222            SET last_heartbeat = NOW()
223            WHERE name = $1 AND instance_id = $2
224            "#,
225            self.daemon_name,
226            self.instance_id,
227        )
228        .execute(&self.db_pool)
229        .await
230        .map_err(crate::ForgeError::Database)?;
231
232        Ok(())
233    }
234
235    /// Get the trace ID for this daemon execution.
236    ///
237    /// Returns the instance_id as a correlation ID.
238    pub fn trace_id(&self) -> String {
239        self.instance_id.to_string()
240    }
241
242    /// Get the parent span for trace propagation.
243    ///
244    /// Use this to create child spans within daemon handlers.
245    pub fn span(&self) -> &Span {
246        &self.span
247    }
248}
249
250impl EnvAccess for DaemonContext {
251    fn env_provider(&self) -> &dyn EnvProvider {
252        self.env_provider.as_ref()
253    }
254}
255
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use crate::env::MockEnvProvider;
    use crate::error::ForgeError;

    /// Build a context over a lazily-connecting pool (these tests never touch
    /// the database), returning the shutdown sender and the instance id.
    fn lazy_ctx() -> (DaemonContext, watch::Sender<bool>, Uuid) {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool");
        let (shutdown_tx, shutdown_rx) = watch::channel(false);
        let instance_id = Uuid::new_v4();
        let ctx = DaemonContext::new(
            "test_daemon".to_string(),
            instance_id,
            pool,
            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
            shutdown_rx,
        );
        (ctx, shutdown_tx, instance_id)
    }

    #[tokio::test]
    async fn test_daemon_context_creation() {
        let (ctx, shutdown_tx, instance_id) = lazy_ctx();

        assert_eq!(ctx.daemon_name, "test_daemon");
        assert_eq!(ctx.instance_id, instance_id);
        assert!(!ctx.is_shutdown_requested());

        // Signal shutdown
        shutdown_tx.send(true).unwrap();
        assert!(ctx.is_shutdown_requested());
    }

    #[tokio::test]
    async fn test_shutdown_signal() {
        let (ctx, shutdown_tx, _) = lazy_ctx();

        // Spawn a task to signal shutdown after a delay
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            shutdown_tx.send(true).unwrap();
        });

        // Wait for shutdown signal
        tokio::time::timeout(std::time::Duration::from_millis(200), ctx.shutdown_signal())
            .await
            .expect("Shutdown signal should complete");

        assert!(ctx.is_shutdown_requested());
    }

    #[tokio::test]
    async fn dispatch_job_returns_internal_when_dispatcher_unset() {
        let (ctx, _tx, _id) = lazy_ctx();
        let err = ctx
            .dispatch_job("send_email", serde_json::json!({"to": "x"}))
            .await
            .unwrap_err();
        match err {
            ForgeError::Internal { context: msg, .. } => assert!(msg.contains("Job dispatch")),
            other => panic!("expected Internal error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn cancel_job_returns_internal_when_dispatcher_unset() {
        let (ctx, _tx, _id) = lazy_ctx();
        let err = ctx.cancel_job(Uuid::new_v4(), None).await.unwrap_err();
        match err {
            ForgeError::Internal { context: msg, .. } => assert!(msg.contains("Job dispatch")),
            other => panic!("expected Internal error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn start_workflow_returns_internal_when_dispatcher_unset() {
        let (ctx, _tx, _id) = lazy_ctx();
        let err = ctx
            .start_workflow("checkout", serde_json::json!({"cart": 1}))
            .await
            .unwrap_err();
        match err {
            ForgeError::Internal { context: msg, .. } => assert!(msg.contains("Workflow dispatch")),
            other => panic!("expected Internal error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn kv_returns_internal_when_store_unset() {
        let (ctx, _tx, _id) = lazy_ctx();
        match ctx.kv() {
            Err(ForgeError::Internal { context: msg, .. }) => assert!(msg.contains("KV store")),
            Err(other) => panic!("expected Internal error, got {other:?}"),
            Ok(_) => panic!("expected the KV store to be unavailable"),
        }
    }

    #[tokio::test]
    async fn trace_id_returns_instance_id_as_string() {
        let (ctx, _tx, instance_id) = lazy_ctx();
        assert_eq!(ctx.trace_id(), instance_id.to_string());
    }

    #[tokio::test]
    async fn set_http_timeout_round_trips_through_http_client() {
        let (mut ctx, _tx, _id) = lazy_ctx();
        // `HttpClient` is opaque, so the applied timeout cannot be inspected.
        // This only checks that `http()` stays usable after setting a timeout
        // and after clearing it back to the default (`None`, i.e. unlimited).
        ctx.set_http_timeout(Some(Duration::from_millis(250)));
        let _client = ctx.http();
        ctx.set_http_timeout(None);
        let _client = ctx.http();
    }

    #[tokio::test]
    async fn with_env_provider_overrides_real_provider() {
        let (ctx, _tx, _id) = lazy_ctx();
        let mut mock = MockEnvProvider::new();
        mock.set("FORGE_TEST_KEY", "hello");
        let ctx = ctx.with_env_provider(Arc::new(mock));
        // `EnvAccess` is already in scope via `use super::*`; calling the trait
        // method proves the override took effect end-to-end.
        assert_eq!(ctx.env("FORGE_TEST_KEY"), Some("hello".to_string()));
        assert_eq!(ctx.env("FORGE_MISSING_KEY"), None);
    }

    #[tokio::test]
    async fn tick_returns_true_after_interval_elapses() {
        let (ctx, _tx, _) = lazy_ctx();
        // Short interval; no shutdown fired — must return true.
        let should_continue = ctx.tick(Duration::from_millis(10)).await;
        assert!(should_continue);
    }

    #[tokio::test]
    async fn tick_returns_false_when_shutdown_fires_before_interval() {
        let (ctx, shutdown_tx, _) = lazy_ctx();
        // Signal shutdown immediately before the long interval would finish.
        shutdown_tx.send(true).unwrap();
        // Interval is very long; shutdown should preempt it and return false quickly.
        let should_continue = tokio::time::timeout(
            Duration::from_millis(200),
            ctx.tick(Duration::from_secs(60)),
        )
        .await
        .expect("tick should return promptly on shutdown");
        assert!(!should_continue);
    }

    #[tokio::test]
    async fn span_returns_current_span_handle() {
        let (ctx, _tx, _id) = lazy_ctx();
        // `lazy_ctx` captured `Span::current()` and no span has been entered
        // since, so the stored handle must have the same id as the span that is
        // current now (both `None` when no subscriber is installed). Span
        // contents themselves cannot be introspected here.
        assert_eq!(ctx.span().id(), Span::current().id());
    }
}