Skip to main content

forge_core/daemon/
context.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use tokio::sync::{Mutex, watch};
5use tracing::Span;
6use uuid::Uuid;
7
8use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
9use crate::function::{JobDispatch, WorkflowDispatch};
10use crate::http::CircuitBreakerClient;
11
12/// Context available to daemon handlers.
13pub struct DaemonContext {
14    /// Daemon name.
15    pub daemon_name: String,
16    /// Unique instance ID for this daemon execution.
17    pub instance_id: Uuid,
18    /// Database pool.
19    db_pool: sqlx::PgPool,
20    /// HTTP client for external calls.
21    http_client: CircuitBreakerClient,
22    /// Default timeout for outbound HTTP requests made through the
23    /// circuit-breaker client. `None` means unlimited.
24    http_timeout: Option<Duration>,
25    /// Shutdown signal receiver (wrapped in Mutex for interior mutability).
26    shutdown_rx: Mutex<watch::Receiver<bool>>,
27    /// Job dispatcher for background jobs.
28    job_dispatch: Option<Arc<dyn JobDispatch>>,
29    /// Workflow dispatcher for starting workflows.
30    workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
31    /// Environment variable provider.
32    env_provider: Arc<dyn EnvProvider>,
33    /// Parent span for trace propagation.
34    span: Span,
35}
36
37impl DaemonContext {
38    /// Create a new daemon context.
39    pub fn new(
40        daemon_name: String,
41        instance_id: Uuid,
42        db_pool: sqlx::PgPool,
43        http_client: CircuitBreakerClient,
44        shutdown_rx: watch::Receiver<bool>,
45    ) -> Self {
46        Self {
47            daemon_name,
48            instance_id,
49            db_pool,
50            http_client,
51            http_timeout: None,
52            shutdown_rx: Mutex::new(shutdown_rx),
53            job_dispatch: None,
54            workflow_dispatch: None,
55            env_provider: Arc::new(RealEnvProvider::new()),
56            span: Span::current(),
57        }
58    }
59
60    /// Set job dispatcher.
61    pub fn with_job_dispatch(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
62        self.job_dispatch = Some(dispatcher);
63        self
64    }
65
66    /// Set workflow dispatcher.
67    pub fn with_workflow_dispatch(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
68        self.workflow_dispatch = Some(dispatcher);
69        self
70    }
71
72    /// Set environment provider.
73    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
74        self.env_provider = provider;
75        self
76    }
77
78    pub fn db(&self) -> crate::function::ForgeDb {
79        crate::function::ForgeDb::from_pool(&self.db_pool)
80    }
81
82    /// Get a `DbConn` for use in shared helper functions.
83    pub fn db_conn(&self) -> crate::function::DbConn<'_> {
84        crate::function::DbConn::Pool(self.db_pool.clone())
85    }
86
87    /// Acquire a connection compatible with sqlx compile-time checked macros.
88    pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
89        Ok(crate::function::ForgeConn::Pool(
90            self.db_pool.acquire().await?,
91        ))
92    }
93
94    pub fn http(&self) -> crate::http::HttpClient {
95        self.http_client.with_timeout(self.http_timeout)
96    }
97
98    pub fn raw_http(&self) -> &reqwest::Client {
99        self.http_client.inner()
100    }
101
102    pub fn set_http_timeout(&mut self, timeout: Option<Duration>) {
103        self.http_timeout = timeout;
104    }
105
106    /// Dispatch a background job.
107    pub async fn dispatch_job<T: serde::Serialize>(
108        &self,
109        job_type: &str,
110        args: T,
111    ) -> crate::Result<Uuid> {
112        let dispatcher = self.job_dispatch.as_ref().ok_or_else(|| {
113            crate::error::ForgeError::Internal("Job dispatch not available".to_string())
114        })?;
115
116        let args_json = serde_json::to_value(args)?;
117        dispatcher.dispatch_by_name(job_type, args_json, None).await
118    }
119
120    /// Start a workflow.
121    pub async fn start_workflow<T: serde::Serialize>(
122        &self,
123        workflow_name: &str,
124        input: T,
125    ) -> crate::Result<Uuid> {
126        let dispatcher = self.workflow_dispatch.as_ref().ok_or_else(|| {
127            crate::error::ForgeError::Internal("Workflow dispatch not available".to_string())
128        })?;
129
130        let input_json = serde_json::to_value(input)?;
131        dispatcher
132            .start_by_name(workflow_name, input_json, None)
133            .await
134    }
135
136    /// Check if shutdown has been requested.
137    pub fn is_shutdown_requested(&self) -> bool {
138        // Use try_lock to avoid blocking; if can't lock, assume not shutdown
139        self.shutdown_rx
140            .try_lock()
141            .map(|rx| *rx.borrow())
142            .unwrap_or(false)
143    }
144
145    /// Wait for shutdown signal.
146    ///
147    /// Use this in a `tokio::select!` to handle graceful shutdown:
148    ///
149    /// ```ignore
150    /// tokio::select! {
151    ///     _ = tokio::time::sleep(Duration::from_secs(60)) => {}
152    ///     _ = ctx.shutdown_signal() => break,
153    /// }
154    /// ```
155    pub async fn shutdown_signal(&self) {
156        let mut rx = self.shutdown_rx.lock().await;
157        // Wait until the value becomes true
158        while !*rx.borrow_and_update() {
159            if rx.changed().await.is_err() {
160                // Channel closed, treat as shutdown
161                break;
162            }
163        }
164    }
165
166    /// Send heartbeat to indicate daemon is alive.
167    pub async fn heartbeat(&self) -> crate::Result<()> {
168        tracing::trace!(daemon.name = %self.daemon_name, "Sending heartbeat");
169
170        sqlx::query!(
171            r#"
172            UPDATE forge_daemons
173            SET last_heartbeat = NOW()
174            WHERE name = $1 AND instance_id = $2
175            "#,
176            self.daemon_name,
177            self.instance_id,
178        )
179        .execute(&self.db_pool)
180        .await
181        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
182
183        Ok(())
184    }
185
186    /// Get the trace ID for this daemon execution.
187    ///
188    /// Returns the instance_id as a correlation ID.
189    pub fn trace_id(&self) -> String {
190        self.instance_id.to_string()
191    }
192
193    /// Get the parent span for trace propagation.
194    ///
195    /// Use this to create child spans within daemon handlers.
196    pub fn span(&self) -> &Span {
197        &self.span
198    }
199}
200
201impl EnvAccess for DaemonContext {
202    fn env_provider(&self) -> &dyn EnvProvider {
203        self.env_provider.as_ref()
204    }
205}
206
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Build a `test_daemon` context around a lazily-connecting pool.
    ///
    /// The pool points at a database that is never contacted by these tests,
    /// so no running PostgreSQL instance is needed. Must be called from
    /// inside a Tokio runtime.
    fn context_with(instance_id: Uuid, shutdown_rx: watch::Receiver<bool>) -> DaemonContext {
        let lazy_pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool");
        let http = CircuitBreakerClient::with_defaults(reqwest::Client::new());

        DaemonContext::new(
            "test_daemon".to_string(),
            instance_id,
            lazy_pool,
            http,
            shutdown_rx,
        )
    }

    #[tokio::test]
    async fn test_daemon_context_creation() {
        let (tx, rx) = watch::channel(false);
        let expected_id = Uuid::new_v4();

        let ctx = context_with(expected_id, rx);

        // Identity fields are stored as given and no shutdown is pending yet.
        assert_eq!(ctx.daemon_name, "test_daemon");
        assert_eq!(ctx.instance_id, expected_id);
        assert!(!ctx.is_shutdown_requested());

        // Flipping the flag on the sender side must become visible.
        tx.send(true).unwrap();
        assert!(ctx.is_shutdown_requested());
    }

    #[tokio::test]
    async fn test_shutdown_signal() {
        let (tx, rx) = watch::channel(false);
        let ctx = context_with(Uuid::new_v4(), rx);

        // Request shutdown from a background task after a short delay.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            tx.send(true).unwrap();
        });

        // The waiter has to resolve well within the timeout.
        let waited = tokio::time::timeout(Duration::from_millis(200), ctx.shutdown_signal()).await;
        waited.expect("Shutdown signal should complete");

        assert!(ctx.is_shutdown_requested());
    }
}