Skip to main content

forge_core/daemon/
context.rs

1use std::sync::Arc;
2
3use tokio::sync::{Mutex, watch};
4use tracing::Span;
5use uuid::Uuid;
6
7use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
8use crate::function::{JobDispatch, WorkflowDispatch};
9
10/// Context available to daemon handlers.
11pub struct DaemonContext {
12    /// Daemon name.
13    pub daemon_name: String,
14    /// Unique instance ID for this daemon execution.
15    pub instance_id: Uuid,
16    /// Database pool.
17    db_pool: sqlx::PgPool,
18    /// HTTP client for external calls.
19    http_client: reqwest::Client,
20    /// Shutdown signal receiver (wrapped in Mutex for interior mutability).
21    shutdown_rx: Mutex<watch::Receiver<bool>>,
22    /// Job dispatcher for background jobs.
23    job_dispatch: Option<Arc<dyn JobDispatch>>,
24    /// Workflow dispatcher for starting workflows.
25    workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
26    /// Environment variable provider.
27    env_provider: Arc<dyn EnvProvider>,
28    /// Parent span for trace propagation.
29    span: Span,
30}
31
32impl DaemonContext {
33    /// Create a new daemon context.
34    pub fn new(
35        daemon_name: String,
36        instance_id: Uuid,
37        db_pool: sqlx::PgPool,
38        http_client: reqwest::Client,
39        shutdown_rx: watch::Receiver<bool>,
40    ) -> Self {
41        Self {
42            daemon_name,
43            instance_id,
44            db_pool,
45            http_client,
46            shutdown_rx: Mutex::new(shutdown_rx),
47            job_dispatch: None,
48            workflow_dispatch: None,
49            env_provider: Arc::new(RealEnvProvider::new()),
50            span: Span::current(),
51        }
52    }
53
54    /// Set job dispatcher.
55    pub fn with_job_dispatch(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
56        self.job_dispatch = Some(dispatcher);
57        self
58    }
59
60    /// Set workflow dispatcher.
61    pub fn with_workflow_dispatch(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
62        self.workflow_dispatch = Some(dispatcher);
63        self
64    }
65
66    /// Set environment provider.
67    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
68        self.env_provider = provider;
69        self
70    }
71
72    pub fn db(&self) -> &sqlx::PgPool {
73        &self.db_pool
74    }
75
76    /// Returns a `DbConn` wrapping the pool for shared helper functions.
77    pub fn db_conn(&self) -> crate::function::DbConn<'_> {
78        crate::function::DbConn::Pool(&self.db_pool)
79    }
80
81    /// Acquire a connection compatible with sqlx compile-time checked macros.
82    pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
83        Ok(crate::function::ForgeConn::Pool(
84            self.db_pool.acquire().await?,
85        ))
86    }
87
88    pub fn http(&self) -> &reqwest::Client {
89        &self.http_client
90    }
91
92    /// Dispatch a background job.
93    pub async fn dispatch_job<T: serde::Serialize>(
94        &self,
95        job_type: &str,
96        args: T,
97    ) -> crate::Result<Uuid> {
98        let dispatcher = self.job_dispatch.as_ref().ok_or_else(|| {
99            crate::error::ForgeError::Internal("Job dispatch not available".to_string())
100        })?;
101
102        let args_json = serde_json::to_value(args)?;
103        dispatcher.dispatch_by_name(job_type, args_json, None).await
104    }
105
106    /// Start a workflow.
107    pub async fn start_workflow<T: serde::Serialize>(
108        &self,
109        workflow_name: &str,
110        input: T,
111    ) -> crate::Result<Uuid> {
112        let dispatcher = self.workflow_dispatch.as_ref().ok_or_else(|| {
113            crate::error::ForgeError::Internal("Workflow dispatch not available".to_string())
114        })?;
115
116        let input_json = serde_json::to_value(input)?;
117        dispatcher
118            .start_by_name(workflow_name, input_json, None)
119            .await
120    }
121
122    /// Check if shutdown has been requested.
123    pub fn is_shutdown_requested(&self) -> bool {
124        // Use try_lock to avoid blocking; if can't lock, assume not shutdown
125        self.shutdown_rx
126            .try_lock()
127            .map(|rx| *rx.borrow())
128            .unwrap_or(false)
129    }
130
131    /// Wait for shutdown signal.
132    ///
133    /// Use this in a `tokio::select!` to handle graceful shutdown:
134    ///
135    /// ```ignore
136    /// tokio::select! {
137    ///     _ = tokio::time::sleep(Duration::from_secs(60)) => {}
138    ///     _ = ctx.shutdown_signal() => break,
139    /// }
140    /// ```
141    pub async fn shutdown_signal(&self) {
142        let mut rx = self.shutdown_rx.lock().await;
143        // Wait until the value becomes true
144        while !*rx.borrow_and_update() {
145            if rx.changed().await.is_err() {
146                // Channel closed, treat as shutdown
147                break;
148            }
149        }
150    }
151
152    /// Send heartbeat to indicate daemon is alive.
153    pub async fn heartbeat(&self) -> crate::Result<()> {
154        tracing::trace!(daemon.name = %self.daemon_name, "Sending heartbeat");
155
156        sqlx::query(
157            r#"
158            UPDATE forge_daemons
159            SET last_heartbeat = NOW()
160            WHERE name = $1 AND instance_id = $2
161            "#,
162        )
163        .bind(&self.daemon_name)
164        .bind(self.instance_id)
165        .execute(&self.db_pool)
166        .await
167        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
168
169        Ok(())
170    }
171
172    /// Get the trace ID for this daemon execution.
173    ///
174    /// Returns the instance_id as a correlation ID.
175    pub fn trace_id(&self) -> String {
176        self.instance_id.to_string()
177    }
178
179    /// Get the parent span for trace propagation.
180    ///
181    /// Use this to create child spans within daemon handlers.
182    pub fn span(&self) -> &Span {
183        &self.span
184    }
185}
186
187impl EnvAccess for DaemonContext {
188    fn env_provider(&self) -> &dyn EnvProvider {
189        self.env_provider.as_ref()
190    }
191}
192
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// Builds a `DaemonContext` named `test_daemon` around a lazily
    /// connecting pool, so constructing it does not need a reachable database.
    fn lazy_context(instance_id: Uuid, shutdown_rx: watch::Receiver<bool>) -> DaemonContext {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool");

        DaemonContext::new(
            "test_daemon".to_string(),
            instance_id,
            pool,
            reqwest::Client::new(),
            shutdown_rx,
        )
    }

    /// A fresh context exposes its identity and reflects the shutdown flag
    /// once the sender flips it.
    #[tokio::test]
    async fn test_daemon_context_creation() {
        let (shutdown_tx, shutdown_rx) = watch::channel(false);
        let instance_id = Uuid::new_v4();
        let ctx = lazy_context(instance_id, shutdown_rx);

        assert_eq!(ctx.daemon_name, "test_daemon");
        assert_eq!(ctx.instance_id, instance_id);
        assert!(!ctx.is_shutdown_requested());

        // Flip the flag and observe it through the context.
        shutdown_tx.send(true).unwrap();
        assert!(ctx.is_shutdown_requested());
    }

    /// `shutdown_signal` resolves once another task sends `true`.
    #[tokio::test]
    async fn test_shutdown_signal() {
        let (shutdown_tx, shutdown_rx) = watch::channel(false);
        let ctx = lazy_context(Uuid::new_v4(), shutdown_rx);

        // Signal shutdown from a background task after a short delay.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            shutdown_tx.send(true).unwrap();
        });

        // The wait must finish well within the timeout.
        let waited = tokio::time::timeout(Duration::from_millis(200), ctx.shutdown_signal()).await;
        waited.expect("Shutdown signal should complete");

        assert!(ctx.is_shutdown_requested());
    }
}
255}